47 #ifndef MUI_MATRIX_ARITHMETIC_H_
48 #define MUI_MATRIX_ARITHMETIC_H_
61 template<
typename ITYPE,
typename VTYPE>
64 if (rows_ != addend.rows_ || cols_ != addend.cols_) {
65 std::cerr <<
"MUI Error [matrix_arithmetic.h]: matrix size mismatch during matrix addition" << std::endl;
69 if (addend.matrix_format_ != matrix_format_) {
73 if (addend.matrix_format_ == format::COO) {
74 addend.
sort_coo(
true,
true,
"overwrite");
75 }
else if (addend.matrix_format_ == format::CSR) {
77 }
else if (addend.matrix_format_ == format::CSC) {
80 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised addend matrix format for matrix operator+()" << std::endl;
81 std::cerr <<
" Please set the addend matrix_format_ as:" << std::endl;
82 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
83 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
84 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
90 if (!this->is_sorted_unique(
"matrix_arithmetic.h",
"operator+()")){
91 if (matrix_format_ == format::COO) {
92 this->sort_coo(
true,
true,
"overwrite");
93 }
else if (matrix_format_ == format::CSR) {
94 this->sort_csr(
true,
"overwrite");
95 }
else if (matrix_format_ == format::CSC) {
96 this->sort_csc(
true,
"overwrite");
98 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised matrix format for matrix operator+()" << std::endl;
99 std::cerr <<
" Please set the matrix_format_ as:" << std::endl;
100 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
101 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
102 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
110 if (matrix_format_ == format::COO) {
113 res.matrix_coo.values_.reserve(matrix_coo.values_.size() + addend.matrix_coo.values_.size());
114 res.matrix_coo.row_indices_.reserve(matrix_coo.row_indices_.size() + addend.matrix_coo.row_indices_.size());
115 res.matrix_coo.col_indices_.reserve(matrix_coo.col_indices_.size() + addend.matrix_coo.col_indices_.size());
118 res.matrix_coo.values_ = std::vector<VTYPE>(matrix_coo.values_.begin(), matrix_coo.values_.end());
119 res.matrix_coo.row_indices_ = std::vector<ITYPE>(matrix_coo.row_indices_.begin(), matrix_coo.row_indices_.end());
120 res.matrix_coo.col_indices_ = std::vector<ITYPE>(matrix_coo.col_indices_.begin(), matrix_coo.col_indices_.end());
123 res.matrix_coo.values_.insert(res.matrix_coo.values_.end(), addend.matrix_coo.values_.begin(), addend.matrix_coo.values_.end());
124 res.matrix_coo.row_indices_.insert(res.matrix_coo.row_indices_.end(), addend.matrix_coo.row_indices_.begin(), addend.matrix_coo.row_indices_.end());
125 res.matrix_coo.col_indices_.insert(res.matrix_coo.col_indices_.end(), addend.matrix_coo.col_indices_.begin(), addend.matrix_coo.col_indices_.end());
129 res.nnz_ = res.matrix_coo.values_.size();
131 }
else if (matrix_format_ == format::CSR) {
134 res.matrix_csr.values_.reserve(matrix_csr.values_.size() + addend.matrix_csr.values_.size());
135 res.matrix_csr.row_ptrs_.resize(rows_ + 1);
136 res.matrix_csr.col_indices_.reserve(matrix_csr.col_indices_.size() + addend.matrix_csr.col_indices_.size());
139 while (row < rows_) {
140 ITYPE start = matrix_csr.row_ptrs_[row];
141 ITYPE end = matrix_csr.row_ptrs_[row + 1];
143 ITYPE addend_start = addend.matrix_csr.row_ptrs_[row];
144 ITYPE addend_end = addend.matrix_csr.row_ptrs_[row + 1];
146 res.matrix_csr.row_ptrs_[0] = 0;
150 ITYPE j = addend_start;
151 while (i < end && j < addend_end) {
152 ITYPE col = matrix_csr.col_indices_[i];
153 ITYPE addend_col = addend.matrix_csr.col_indices_[j];
155 if (col == addend_col) {
158 res.matrix_csr.values_.emplace_back(matrix_csr.values_[i] + addend.matrix_csr.values_[j]);
159 res.matrix_csr.col_indices_.emplace_back(col);
163 }
else if (col < addend_col) {
166 res.matrix_csr.values_.emplace_back(matrix_csr.values_[i]);
167 res.matrix_csr.col_indices_.emplace_back(col);
173 res.matrix_csr.values_.emplace_back(addend.matrix_csr.values_[j]);
174 res.matrix_csr.col_indices_.emplace_back(addend_col);
181 for (; i < end; i++) {
183 res.matrix_csr.values_.emplace_back(matrix_csr.values_[i]);
184 res.matrix_csr.col_indices_.emplace_back(matrix_csr.col_indices_[i]);
189 for (; j < addend_end; j++) {
191 res.matrix_csr.values_.emplace_back(addend.matrix_csr.values_[j]);
192 res.matrix_csr.col_indices_.emplace_back(addend.matrix_csr.col_indices_[j]);
197 res.nnz_ = res.matrix_csr.col_indices_.size();
198 res.matrix_csr.row_ptrs_[row + 1] = res.nnz_;
203 }
else if (matrix_format_ == format::CSC) {
206 res.matrix_csc.values_.reserve(matrix_csc.values_.size() + addend.matrix_csc.values_.size());
207 res.matrix_csc.row_indices_.reserve(matrix_csc.row_indices_.size() + addend.matrix_csc.row_indices_.size());
208 res.matrix_csc.col_ptrs_.resize(cols_ + 1);
211 while (column < cols_) {
212 ITYPE start = matrix_csc.col_ptrs_[column];
213 ITYPE end = matrix_csc.col_ptrs_[column + 1];
215 ITYPE addend_start = addend.matrix_csc.col_ptrs_[column];
216 ITYPE addend_end = addend.matrix_csc.col_ptrs_[column + 1];
218 res.matrix_csc.col_ptrs_[0] = 0;
222 ITYPE j = addend_start;
223 while (i < end && j < addend_end) {
224 ITYPE row = matrix_csc.row_indices_[i];
225 ITYPE addend_row = addend.matrix_csc.row_indices_[j];
227 if (row == addend_row) {
230 res.matrix_csc.values_.emplace_back(matrix_csc.values_[i] + addend.matrix_csc.values_[j]);
231 res.matrix_csc.row_indices_.emplace_back(row);
235 }
else if (row < addend_row) {
238 res.matrix_csc.values_.emplace_back(matrix_csc.values_[i]);
239 res.matrix_csc.row_indices_.emplace_back(row);
245 res.matrix_csc.values_.emplace_back(addend.matrix_csc.values_[j]);
246 res.matrix_csc.row_indices_.emplace_back(addend_row);
253 for (; i < end; i++) {
255 res.matrix_csc.values_.emplace_back(matrix_csc.values_[i]);
256 res.matrix_csc.row_indices_.emplace_back(matrix_csc.row_indices_[i]);
261 for (; j < addend_end; j++) {
263 res.matrix_csc.values_.emplace_back(addend.matrix_csc.values_[j]);
264 res.matrix_csc.row_indices_.emplace_back(addend.matrix_csc.row_indices_[j]);
269 res.nnz_ = res.matrix_csc.row_indices_.size();
270 res.matrix_csc.col_ptrs_[column + 1] = res.nnz_;
276 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised matrix format for matrix operator+()" << std::endl;
277 std::cerr <<
" Please set the matrix_format_ as:" << std::endl;
278 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
279 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
280 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
288 template<
typename ITYPE,
typename VTYPE>
290 if (rows_ != subtrahend.rows_ || cols_ != subtrahend.cols_) {
291 std::cerr <<
"MUI Error [matrix_arithmetic.h]: matrix size mismatch during matrix subtraction" << std::endl;
295 if (subtrahend.matrix_format_ != matrix_format_) {
299 if (subtrahend.matrix_format_ == format::COO) {
300 subtrahend.
sort_coo(
true,
true,
"overwrite");
301 }
else if (subtrahend.matrix_format_ == format::CSR) {
302 subtrahend.
sort_csr(
true,
"overwrite");
303 }
else if (subtrahend.matrix_format_ == format::CSC) {
304 subtrahend.
sort_csc(
true,
"overwrite");
306 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised subtrahend matrix format for matrix operator-()" << std::endl;
307 std::cerr <<
" Please set the subtrahend matrix_format_ as:" << std::endl;
308 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
309 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
310 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
316 if (!this->is_sorted_unique(
"matrix_arithmetic.h",
"operator-()")){
317 if (matrix_format_ == format::COO) {
318 this->sort_coo(
true,
true,
"overwrite");
319 }
else if (matrix_format_ == format::CSR) {
320 this->sort_csr(
true,
"overwrite");
321 }
else if (matrix_format_ == format::CSC) {
322 this->sort_csc(
true,
"overwrite");
324 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised matrix format for matrix operator-()" << std::endl;
325 std::cerr <<
" Please set the matrix_format_ as:" << std::endl;
326 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
327 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
328 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
336 if (matrix_format_ == format::COO) {
338 std::vector<VTYPE> subtrahend_value;
339 subtrahend_value.reserve(subtrahend.matrix_coo.values_.size());
341 for (VTYPE &element : subtrahend.matrix_coo.values_) {
342 subtrahend_value.emplace_back(element*(-1));
346 res.matrix_coo.values_.reserve(matrix_coo.values_.size() + subtrahend.matrix_coo.values_.size());
347 res.matrix_coo.row_indices_.reserve(matrix_coo.row_indices_.size() + subtrahend.matrix_coo.row_indices_.size());
348 res.matrix_coo.col_indices_.reserve(matrix_coo.col_indices_.size() + subtrahend.matrix_coo.col_indices_.size());
351 res.matrix_coo.values_ = std::vector<VTYPE>(matrix_coo.values_.begin(), matrix_coo.values_.end());
352 res.matrix_coo.row_indices_ = std::vector<ITYPE>(matrix_coo.row_indices_.begin(), matrix_coo.row_indices_.end());
353 res.matrix_coo.col_indices_ = std::vector<ITYPE>(matrix_coo.col_indices_.begin(), matrix_coo.col_indices_.end());
356 res.matrix_coo.values_.insert(res.matrix_coo.values_.end(), subtrahend_value.begin(), subtrahend_value.end());
357 res.matrix_coo.row_indices_.insert(res.matrix_coo.row_indices_.end(), subtrahend.matrix_coo.row_indices_.begin(), subtrahend.matrix_coo.row_indices_.end());
358 res.matrix_coo.col_indices_.insert(res.matrix_coo.col_indices_.end(), subtrahend.matrix_coo.col_indices_.begin(), subtrahend.matrix_coo.col_indices_.end());
363 }
else if (matrix_format_ == format::CSR) {
366 res.matrix_csr.values_.reserve(matrix_csr.values_.size() + subtrahend.matrix_csr.values_.size());
367 res.matrix_csr.row_ptrs_.resize(rows_ + 1);
368 res.matrix_csr.col_indices_.reserve(matrix_csr.col_indices_.size() + subtrahend.matrix_csr.col_indices_.size());
371 while (row < rows_) {
372 ITYPE start = matrix_csr.row_ptrs_[row];
373 ITYPE end = matrix_csr.row_ptrs_[row + 1];
375 ITYPE subtrahend_start = subtrahend.matrix_csr.row_ptrs_[row];
376 ITYPE subtrahend_end = subtrahend.matrix_csr.row_ptrs_[row + 1];
378 res.matrix_csr.row_ptrs_[0] = 0;
382 ITYPE j = subtrahend_start;
383 while (i < end && j < subtrahend_end) {
384 ITYPE col = matrix_csr.col_indices_[i];
385 ITYPE subtrahend_col = subtrahend.matrix_csr.col_indices_[j];
387 if (col == subtrahend_col) {
390 res.matrix_csr.values_.emplace_back(matrix_csr.values_[i] - subtrahend.matrix_csr.values_[j]);
391 res.matrix_csr.col_indices_.emplace_back(col);
395 }
else if (col < subtrahend_col) {
398 res.matrix_csr.values_.emplace_back(matrix_csr.values_[i]);
399 res.matrix_csr.col_indices_.emplace_back(col);
405 res.matrix_csr.values_.emplace_back(-subtrahend.matrix_csr.values_[j]);
406 res.matrix_csr.col_indices_.emplace_back(subtrahend_col);
413 for (; i < end; i++) {
415 res.matrix_csr.values_.emplace_back(matrix_csr.values_[i]);
416 res.matrix_csr.col_indices_.emplace_back(matrix_csr.col_indices_[i]);
421 for (; j < subtrahend_end; j++) {
423 res.matrix_csr.values_.emplace_back(-subtrahend.matrix_csr.values_[j]);
424 res.matrix_csr.col_indices_.emplace_back(subtrahend.matrix_csr.col_indices_[j]);
429 res.nnz_ = res.matrix_csr.col_indices_.size();
430 res.matrix_csr.row_ptrs_[row + 1] = res.nnz_;
435 }
else if (matrix_format_ == format::CSC) {
438 res.matrix_csc.values_.reserve(matrix_csc.values_.size() + subtrahend.matrix_csc.values_.size());
439 res.matrix_csc.row_indices_.reserve(matrix_csc.row_indices_.size() + subtrahend.matrix_csc.row_indices_.size());
440 res.matrix_csc.col_ptrs_.resize(cols_ + 1);
443 while (column < cols_) {
444 ITYPE start = matrix_csc.col_ptrs_[column];
445 ITYPE end = matrix_csc.col_ptrs_[column + 1];
447 ITYPE subtrahend_start = subtrahend.matrix_csc.col_ptrs_[column];
448 ITYPE subtrahend_end = subtrahend.matrix_csc.col_ptrs_[column + 1];
450 res.matrix_csc.col_ptrs_[0] = 0;
454 ITYPE j = subtrahend_start;
455 while (i < end && j < subtrahend_end) {
456 ITYPE row = matrix_csc.row_indices_[i];
457 ITYPE subtrahend_row = subtrahend.matrix_csc.row_indices_[j];
459 if (row == subtrahend_row) {
462 res.matrix_csc.values_.emplace_back(matrix_csc.values_[i] - subtrahend.matrix_csc.values_[j]);
463 res.matrix_csc.row_indices_.emplace_back(row);
467 }
else if (row < subtrahend_row) {
470 res.matrix_csc.values_.emplace_back(matrix_csc.values_[i]);
471 res.matrix_csc.row_indices_.emplace_back(row);
477 res.matrix_csc.values_.emplace_back(-subtrahend.matrix_csc.values_[j]);
478 res.matrix_csc.row_indices_.emplace_back(subtrahend_row);
485 for (; i < end; i++) {
487 res.matrix_csc.values_.emplace_back(matrix_csc.values_[i]);
488 res.matrix_csc.row_indices_.emplace_back(matrix_csc.row_indices_[i]);
493 for (; j < subtrahend_end; j++) {
495 res.matrix_csc.values_.emplace_back(-subtrahend.matrix_csc.values_[j]);
496 res.matrix_csc.row_indices_.emplace_back(subtrahend.matrix_csc.row_indices_[j]);
501 res.nnz_ = res.matrix_csc.row_indices_.size();
502 res.matrix_csc.col_ptrs_[column + 1] = res.nnz_;
508 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised matrix format for matrix operator-()" << std::endl;
509 std::cerr <<
" Please set the matrix_format_ as:" << std::endl;
510 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
511 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
512 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
520 template<
typename ITYPE,
typename VTYPE>
523 if (cols_ != multiplicand.rows_) {
524 std::cerr <<
"MUI Error [matrix_arithmetic.h]: matrix size mismatch during matrix multiplication" << std::endl;
528 if (multiplicand.matrix_format_ != matrix_format_) {
532 if (multiplicand.matrix_format_ == format::COO) {
533 multiplicand.
sort_coo(
true,
true,
"overwrite");
534 }
else if (multiplicand.matrix_format_ == format::CSR) {
535 multiplicand.
sort_csr(
true,
"overwrite");
536 }
else if (multiplicand.matrix_format_ == format::CSC) {
537 multiplicand.
sort_csc(
true,
"overwrite");
539 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised multiplicand matrix format for matrix operator*()" << std::endl;
540 std::cerr <<
" Please set the multiplicand matrix_format_ as:" << std::endl;
541 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
542 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
543 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
549 if (!this->is_sorted_unique(
"matrix_arithmetic.h",
"operator*()")){
550 if (matrix_format_ == format::COO) {
551 this->sort_coo(
true,
true,
"overwrite");
552 }
else if (matrix_format_ == format::CSR) {
553 this->sort_csr(
true,
"overwrite");
554 }
else if (matrix_format_ == format::CSC) {
555 this->sort_csc(
true,
"overwrite");
557 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised matrix format for matrix operator*()" << std::endl;
558 std::cerr <<
" Please set the matrix_format_ as:" << std::endl;
559 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
560 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
561 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
569 if (matrix_format_ == format::COO) {
572 res.matrix_coo.values_.reserve((matrix_coo.values_.size() <= multiplicand.matrix_coo.values_.size()) ? multiplicand.matrix_coo.values_.size() : matrix_coo.values_.size());
573 res.matrix_coo.row_indices_.reserve((matrix_coo.row_indices_.size() <= multiplicand.matrix_coo.row_indices_.size()) ? multiplicand.matrix_coo.row_indices_.size() : matrix_coo.row_indices_.size());
574 res.matrix_coo.col_indices_.reserve((matrix_coo.col_indices_.size() <= multiplicand.matrix_coo.col_indices_.size()) ? multiplicand.matrix_coo.col_indices_.size() : matrix_coo.col_indices_.size());
576 for (ITYPE i = 0; i < static_cast<ITYPE>(matrix_coo.row_indices_.size()); ++i) {
577 for (ITYPE j = 0; j < static_cast<ITYPE>(multiplicand.matrix_coo.col_indices_.size()); ++j) {
578 if (matrix_coo.col_indices_[i] == multiplicand.matrix_coo.row_indices_[j]) {
580 VTYPE value = matrix_coo.values_[i] * multiplicand.matrix_coo.values_[j];
582 res.matrix_coo.values_.emplace_back(value);
583 res.matrix_coo.row_indices_.emplace_back(matrix_coo.row_indices_[i]);
584 res.matrix_coo.col_indices_.emplace_back(multiplicand.matrix_coo.col_indices_[j]);
593 res.nnz_ = res.matrix_coo.values_.size();
595 }
else if (matrix_format_ == format::CSR) {
598 res.matrix_csr.values_.reserve((matrix_csr.values_.size() <= multiplicand.matrix_csr.values_.size()) ? multiplicand.matrix_csr.values_.size() : matrix_csr.values_.size());
599 res.matrix_csr.row_ptrs_.resize(rows_+1);
600 res.matrix_csr.col_indices_.reserve((matrix_csr.col_indices_.size() <= multiplicand.matrix_csr.col_indices_.size()) ? multiplicand.matrix_csr.col_indices_.size() : matrix_csr.col_indices_.size());
603 std::vector<VTYPE> intermediate(multiplicand.cols_, 0.0);
605 res.matrix_csr.row_ptrs_[0] = 0;
608 for (ITYPE i = 0; i < rows_; ++i) {
610 std::fill(intermediate.begin(), intermediate.end(), 0.0);
612 ITYPE start = matrix_csr.row_ptrs_[i];
613 ITYPE end = matrix_csr.row_ptrs_[i + 1];
616 for (ITYPE j = start; j < end; ++j) {
618 ITYPE col = matrix_csr.col_indices_[j];
619 VTYPE value = matrix_csr.values_[j];
621 ITYPE multiplicand_start = multiplicand.matrix_csr.row_ptrs_[col];
622 ITYPE multiplicand_end = multiplicand.matrix_csr.row_ptrs_[col + 1];
625 for (ITYPE k = multiplicand_start; k < multiplicand_end; ++k) {
626 ITYPE multiplicand_col = multiplicand.matrix_csr.col_indices_[k];
627 VTYPE multiplicand_value = multiplicand.matrix_csr.values_[k];
628 intermediate[multiplicand_col] += value * multiplicand_value;
633 for (ITYPE j = 0; j < multiplicand.cols_; ++j) {
634 VTYPE result_value = intermediate[j];
636 res.matrix_csr.values_.emplace_back(result_value);
637 res.matrix_csr.col_indices_.emplace_back(j);
640 res.matrix_csr.row_ptrs_[i+1]=res.matrix_csr.values_.size();
643 res.nnz_ = res.matrix_csr.values_.size();
645 }
else if (matrix_format_ == format::CSC) {
648 res.matrix_csc.values_.reserve((matrix_csc.values_.size() <= multiplicand.matrix_csc.values_.size()) ? multiplicand.matrix_csc.values_.size() : matrix_csc.values_.size());
649 res.matrix_csc.row_indices_.reserve((matrix_csc.row_indices_.size() <= multiplicand.matrix_csc.row_indices_.size()) ? multiplicand.matrix_csc.row_indices_.size() : matrix_csc.row_indices_.size());
650 res.matrix_csc.col_ptrs_.resize(cols_+1);
653 std::vector<VTYPE> intermediate(rows_, 0.0);
655 res.matrix_csc.col_ptrs_[0] = 0;
658 for (ITYPE j = 0; j < multiplicand.cols_; ++j) {
660 std::fill(intermediate.begin(), intermediate.end(), 0.0);
662 ITYPE multiplicand_start = multiplicand.matrix_csc.col_ptrs_[j];
663 ITYPE multiplicand_end = multiplicand.matrix_csc.col_ptrs_[j + 1];
666 for (ITYPE k = multiplicand_start; k < multiplicand_end; ++k) {
668 ITYPE multiplicand_row = multiplicand.matrix_csc.row_indices_[k];
669 VTYPE multiplicand_value = multiplicand.matrix_csc.values_[k];
671 ITYPE start = matrix_csc.col_ptrs_[multiplicand_row];
672 ITYPE end = matrix_csc.col_ptrs_[multiplicand_row + 1];
675 for (ITYPE i = start; i < end; ++i) {
676 ITYPE row = matrix_csc.row_indices_[i];
677 VTYPE value = matrix_csc.values_[i];
678 intermediate[row] += value * multiplicand_value;
683 for (ITYPE i = 0; i < multiplicand.rows_; ++i) {
684 VTYPE result_value = intermediate[i];
686 res.matrix_csc.values_.emplace_back(result_value);
687 res.matrix_csc.row_indices_.emplace_back(i);
690 res.matrix_csc.col_ptrs_[j+1]=res.matrix_csc.values_.size();
693 res.nnz_ = res.matrix_csc.values_.size();
696 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised matrix format for matrix operator*()" << std::endl;
697 std::cerr <<
" Please set the matrix_format_ as:" << std::endl;
698 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
699 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
700 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
708 template <
typename ITYPE,
typename VTYPE>
709 template <
typename STYPE>
711 static_assert(std::is_convertible<STYPE, VTYPE>::value,
712 "MUI Error [matrix_arithmetic.h]: scalar type cannot be converted to matrix element type in scalar multiplication");
717 if (matrix_format_ == format::COO) {
719 for (VTYPE &element : res.matrix_coo.values_) {
724 }
else if (matrix_format_ == format::CSR) {
726 for (VTYPE &element : res.matrix_csr.values_) {
731 }
else if (matrix_format_ == format::CSC) {
733 for (VTYPE &element : res.matrix_csc.values_) {
739 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised matrix format for matrix scalar operator*()" << std::endl;
740 std::cerr <<
" Please set the matrix_format_ as:" << std::endl;
741 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
742 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
743 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
752 template<
typename ITYPE,
typename VTYPE,
typename STYPE>
754 return exist_mat * scalar;
758 template <
typename ITYPE,
typename VTYPE>
760 assert(((cols_ == 1)&&(exist_mat.cols_ == 1)) &&
761 "MUI Error [matrix_arithmetic.h]: dot_product function only works for column vectors");
766 "MUI Error [matrix_arithmetic.h]: result of dot_product function should be a scalar");
771 template <
typename ITYPE,
typename VTYPE>
773 if (rows_ != exist_mat.rows_ || cols_ != exist_mat.cols_) {
774 std::cerr <<
"MUI Error [matrix_arithmetic.h]: matrix size mismatch during matrix Hadamard product" << std::endl;
778 if (exist_mat.matrix_format_ != matrix_format_) {
781 if (!exist_mat.
is_sorted_unique(
"matrix_arithmetic.h",
"hadamard_product()")){
782 if (exist_mat.matrix_format_ == format::COO) {
783 exist_mat.
sort_coo(
true,
true,
"overwrite");
784 }
else if (exist_mat.matrix_format_ == format::CSR) {
785 exist_mat.
sort_csr(
true,
"overwrite");
786 }
else if (exist_mat.matrix_format_ == format::CSC) {
787 exist_mat.
sort_csc(
true,
"overwrite");
789 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised exist_mat matrix format for matrix hadamard_product()" << std::endl;
790 std::cerr <<
" Please set the exist_mat matrix_format_ as:" << std::endl;
791 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
792 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
793 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
799 if (!this->is_sorted_unique(
"matrix_arithmetic.h",
"hadamard_product()")){
800 if (matrix_format_ == format::COO) {
801 this->sort_coo(
true,
true,
"overwrite");
802 }
else if (matrix_format_ == format::CSR) {
803 this->sort_csr(
true,
"overwrite");
804 }
else if (matrix_format_ == format::CSC) {
805 this->sort_csc(
true,
"overwrite");
807 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised matrix format for matrix hadamard_product()" << std::endl;
808 std::cerr <<
" Please set the matrix_format_ as:" << std::endl;
809 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
810 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
811 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
819 if (matrix_format_ == format::COO) {
822 res.matrix_coo.values_.reserve(matrix_coo.values_.size() + exist_mat.matrix_coo.values_.size());
823 res.matrix_coo.row_indices_.reserve(matrix_coo.row_indices_.size() + exist_mat.matrix_coo.row_indices_.size());
824 res.matrix_coo.col_indices_.reserve(matrix_coo.col_indices_.size() + exist_mat.matrix_coo.col_indices_.size());
827 res.matrix_coo.values_ = std::vector<VTYPE>(matrix_coo.values_.begin(), matrix_coo.values_.end());
828 res.matrix_coo.row_indices_ = std::vector<ITYPE>(matrix_coo.row_indices_.begin(), matrix_coo.row_indices_.end());
829 res.matrix_coo.col_indices_ = std::vector<ITYPE>(matrix_coo.col_indices_.begin(), matrix_coo.col_indices_.end());
832 res.matrix_coo.values_.insert(res.matrix_coo.values_.end(), exist_mat.matrix_coo.values_.begin(), exist_mat.matrix_coo.values_.end());
833 res.matrix_coo.row_indices_.insert(res.matrix_coo.row_indices_.end(), exist_mat.matrix_coo.row_indices_.begin(), exist_mat.matrix_coo.row_indices_.end());
834 res.matrix_coo.col_indices_.insert(res.matrix_coo.col_indices_.end(), exist_mat.matrix_coo.col_indices_.begin(), exist_mat.matrix_coo.col_indices_.end());
837 res.
sort_coo(
true,
true,
"multiply");
839 }
else if (matrix_format_ == format::CSR) {
842 res.matrix_csr.values_.reserve(matrix_csr.values_.size() + exist_mat.matrix_csr.values_.size());
843 res.matrix_csr.row_ptrs_.resize(rows_ + 1);
844 res.matrix_csr.col_indices_.reserve(matrix_csr.col_indices_.size() + exist_mat.matrix_csr.col_indices_.size());
846 res.matrix_csr.row_ptrs_[0] = 0;
849 while (row < rows_) {
850 ITYPE start = matrix_csr.row_ptrs_[row];
851 ITYPE end = matrix_csr.row_ptrs_[row + 1];
853 ITYPE exist_mat_start = exist_mat.matrix_csr.row_ptrs_[row];
854 ITYPE exist_mat_end = exist_mat.matrix_csr.row_ptrs_[row + 1];
858 ITYPE j = exist_mat_start;
859 while (i < end && j < exist_mat_end) {
860 ITYPE col = matrix_csr.col_indices_[i];
861 ITYPE exist_mat_col = exist_mat.matrix_csr.col_indices_[j];
863 if (col == exist_mat_col) {
866 res.matrix_csr.values_.emplace_back(matrix_csr.values_[i] * exist_mat.matrix_csr.values_[j]);
867 res.matrix_csr.col_indices_.emplace_back(col);
871 }
else if (col < exist_mat_col) {
879 res.nnz_ = res.matrix_csr.col_indices_.size();
880 res.matrix_csr.row_ptrs_[row + 1] = res.nnz_;
885 }
else if (matrix_format_ == format::CSC) {
888 res.matrix_csc.values_.reserve(matrix_csc.values_.size() + exist_mat.matrix_csc.values_.size());
889 res.matrix_csc.row_indices_.reserve(matrix_csc.row_indices_.size() + exist_mat.matrix_csc.row_indices_.size());
890 res.matrix_csc.col_ptrs_.resize(cols_ + 1);
892 res.matrix_csc.col_ptrs_[0] = 0;
895 while (column < cols_) {
896 ITYPE start = matrix_csc.col_ptrs_[column];
897 ITYPE end = matrix_csc.col_ptrs_[column + 1];
899 ITYPE exist_mat_start = exist_mat.matrix_csc.col_ptrs_[column];
900 ITYPE exist_mat_end = exist_mat.matrix_csc.col_ptrs_[column + 1];
904 ITYPE j = exist_mat_start;
905 while (i < end && j < exist_mat_end) {
906 ITYPE row = matrix_csc.row_indices_[i];
907 ITYPE exist_mat_row = exist_mat.matrix_csc.row_indices_[j];
911 res.matrix_csc.values_.emplace_back(matrix_csc.values_[i] * exist_mat.matrix_csc.values_[j]);
912 res.matrix_csc.row_indices_.emplace_back(row);
915 }
else if (row < exist_mat_row) {
923 res.nnz_ = res.matrix_csc.row_indices_.size();
924 res.matrix_csc.col_ptrs_[column + 1] = res.nnz_;
930 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised matrix format for matrix hadamard_product()" << std::endl;
931 std::cerr <<
" Please set the matrix_format_ as:" << std::endl;
932 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
933 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
934 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
942 template <
typename ITYPE,
typename VTYPE>
947 if (matrix_format_ == format::COO) {
949 if (performSortAndUniqueCheck){
951 res.
sort_coo(
true,
true,
"overwrite");
957 }
else if (matrix_format_ == format::CSR) {
959 res.
format_conversion(
"CSC", performSortAndUniqueCheck, performSortAndUniqueCheck,
"overwrite");
963 }
else if (matrix_format_ == format::CSC) {
965 res.
format_conversion(
"CSR", performSortAndUniqueCheck, performSortAndUniqueCheck,
"overwrite");
970 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Unrecognised matrix format for matrix transpose()" << std::endl;
971 std::cerr <<
" Please set the matrix_format_ as:" << std::endl;
972 std::cerr <<
" format::COO: COOrdinate format" << std::endl;
973 std::cerr <<
" format::CSR (default): Compressed Sparse Row format" << std::endl;
974 std::cerr <<
" format::CSC: Compressed Sparse Column format" << std::endl;
983 template <
typename ITYPE,
typename VTYPE>
989 std::cerr <<
"MUI Error [matrix_arithmetic.h]: L & U Matrices must be null or same size of initial matrix in LU decomposition" << std::endl;
994 std::cerr <<
"MUI Error [matrix_arithmetic.h]: L & U Matrices must be empty in LU decomposition" << std::endl;
998 if (rows_ != cols_) {
999 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Only square matrix can perform LU decomposition" << std::endl;
1014 for (ITYPE i = 0; i < rows_; ++i) {
1016 for (ITYPE k = i; k < cols_; ++k) {
1018 for (ITYPE j = 0; j < i; ++j) {
1025 for (ITYPE k = i; k < rows_; k++) {
1027 L.
set_value(i, i,
static_cast<VTYPE
>(1.0));
1030 for (ITYPE j = 0; j < i; ++j) {
1034 "MUI Error [matrix_arithmetic.h]: Divide by zero assert for U.get_value(i, i)");
1042 template <
typename ITYPE,
typename VTYPE>
1048 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Q & R Matrices must be null in QR decomposition" << std::endl;
1052 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Q & R Matrices must be empty in QR decomposition" << std::endl;
1055 assert((rows_ >= cols_) &&
1056 "MUI Error [matrix_arithmetic.h]: number of rows of matrix should larger or equals to number of columns in QR decomposition");
1070 std::vector<VTYPE> r_diag (cols_);
1073 for (ITYPE c = 0; c <cols_; ++c) {
1077 for (ITYPE r = c; r < rows_; ++r)
1078 nrm = std::sqrt((nrm * nrm) + (mat_copy.
get_value(r, c) * mat_copy.
get_value(r, c)));
1080 if (nrm !=
static_cast<VTYPE
>(0.0)) {
1083 if (mat_copy.
get_value(c, c) <
static_cast<VTYPE
>(0.0))
1086 for (ITYPE r = c; r < rows_; ++r)
1092 for (ITYPE j = c + 1; j < cols_; ++j) {
1095 for (ITYPE r = c; r < rows_; ++r)
1099 for (ITYPE r = c; r < rows_; ++r)
1107 for (ITYPE c = cols_ - 1; c >= 0; --c) {
1108 Q.
set_value(c, c,
static_cast<VTYPE
>(1.0));
1110 for (ITYPE cc = c; cc < cols_; ++cc)
1111 if (mat_copy.
get_value(c, c) !=
static_cast<VTYPE
>(0.0)) {
1114 for (ITYPE r = c; r < rows_; ++r)
1118 for (ITYPE r = c; r < rows_; ++r)
1124 for (ITYPE c = 0; c < cols_; ++c)
1125 for (ITYPE r = 0; r < rows_; ++r)
1133 template <
typename ITYPE,
typename VTYPE>
1135 if (rows_ != cols_) {
1136 std::cerr <<
"MUI Error [matrix_arithmetic.h]: Matrix must be square to find its inverse" << std::endl;
1143 for (ITYPE r = 0; r < rows_; ++r) {
1146 VTYPE max_value=
static_cast<VTYPE
>(-1.0);
1150 for (ITYPE rb = r; rb < rows_; ++rb) {
1151 const VTYPE tmp = std::abs(mat_copy.
get_value(rb, r));
1160 "MUI Error [matrix_arithmetic.h]: Divide by zero assert for mat_copy.get_value(max_row, r). Cannot perform matrix invert due to singular matrix.");
1163 for (ITYPE c = 0; c < cols_; ++c)
1170 const ITYPE indx = ppivot;
1173 for (ITYPE c = 0; c < cols_; ++c)
1176 const VTYPE diag = mat_copy.
get_value(r, r);
1178 for (ITYPE c = 0; c < cols_; ++c) {
1183 for (ITYPE rr = 0; rr < rows_; ++rr)
1185 const VTYPE off_diag = mat_copy.
get_value(rr, r);
1187 for (ITYPE c = 0; c < cols_; ++c) {
1201 template<
typename ITYPE,
typename VTYPE>
1203 assert((matrix_format_ == format::COO) &&
1204 "MUI Error [matrix_arithmetic.h]: index_reinterpretation() is for COO format only.");
1206 std::swap(matrix_coo.row_indices_, matrix_coo.col_indices_);
1207 ITYPE temp_index = rows_;
1214 template<
typename ITYPE,
typename VTYPE>
1216 assert(((matrix_format_ == format::CSR) || (matrix_format_ == format::CSC)) &&
1217 "MUI Error [matrix_arithmetic.h]: format_reinterpretation() is for CSR or CSC format.");
1219 if (matrix_format_ == format::CSR) {
1221 matrix_csc.col_ptrs_.swap(matrix_csr.row_ptrs_);
1222 matrix_csc.row_indices_.swap(matrix_csr.col_indices_);
1223 matrix_csc.values_.swap(matrix_csr.values_);
1225 ITYPE temp_index = rows_;
1229 matrix_format_ = format::CSC;
1231 matrix_csr.row_ptrs_.clear();
1232 matrix_csr.col_indices_.clear();
1233 matrix_csr.values_.clear();
1235 }
else if (matrix_format_ == format::CSC) {
1237 matrix_csr.row_ptrs_.swap(matrix_csc.col_ptrs_);
1238 matrix_csr.col_indices_.swap(matrix_csc.row_indices_);
1239 matrix_csr.values_.swap(matrix_csc.values_);
1241 ITYPE temp_index = rows_;
1245 matrix_format_ = format::CSR;
1247 matrix_csc.col_ptrs_.clear();
1248 matrix_csc.row_indices_.clear();
1249 matrix_csc.values_.clear();
void index_reinterpretation()
Definition: matrix_arithmetic.h:1202
ITYPE get_rows() const
Definition: matrix_io_info.h:579
bool is_sorted_unique(const std::string &={}, const std::string &={}) const
Definition: matrix_io_info.h:701
sparse_matrix< ITYPE, VTYPE > operator*(sparse_matrix< ITYPE, VTYPE > &)
Definition: matrix_arithmetic.h:521
VTYPE dot_product(sparse_matrix< ITYPE, VTYPE > &) const
Definition: matrix_arithmetic.h:759
void sort_csr(bool=false, const std::string &="overwrite")
Definition: matrix_manipulation.h:1037
sparse_matrix< ITYPE, VTYPE > transpose(bool=true) const
Definition: matrix_arithmetic.h:943
void qr_decomposition(sparse_matrix< ITYPE, VTYPE > &, sparse_matrix< ITYPE, VTYPE > &) const
Definition: matrix_arithmetic.h:1043
void set_value(ITYPE, ITYPE, VTYPE, bool=true)
Definition: matrix_manipulation.h:292
sparse_matrix< ITYPE, VTYPE > operator+(sparse_matrix< ITYPE, VTYPE > &)
Definition: matrix_arithmetic.h:62
VTYPE get_value(ITYPE, ITYPE) const
Definition: matrix_io_info.h:523
sparse_matrix< ITYPE, VTYPE > hadamard_product(sparse_matrix< ITYPE, VTYPE > &)
Definition: matrix_arithmetic.h:772
void resize(ITYPE, ITYPE)
Definition: matrix_manipulation.h:62
ITYPE get_cols() const
Definition: matrix_io_info.h:585
sparse_matrix< ITYPE, VTYPE > inverse() const
Definition: matrix_arithmetic.h:1134
void swap_elements(ITYPE, ITYPE, ITYPE, ITYPE)
Definition: matrix_manipulation.h:407
void sort_coo(bool=true, bool=false, const std::string &="overwrite")
Definition: matrix_manipulation.h:808
void format_conversion(const std::string &="COO", bool=true, bool=false, const std::string &="overwrite")
Definition: matrix_manipulation.h:639
void format_reinterpretation()
Definition: matrix_arithmetic.h:1215
sparse_matrix< ITYPE, VTYPE > operator-(sparse_matrix< ITYPE, VTYPE > &)
Definition: matrix_arithmetic.h:289
bool empty() const
Definition: matrix_io_info.h:665
void lu_decomposition(sparse_matrix< ITYPE, VTYPE > &, sparse_matrix< ITYPE, VTYPE > &) const
Definition: matrix_arithmetic.h:984
void sort_csc(bool=false, const std::string &="overwrite")
Definition: matrix_manipulation.h:1168
u u u u u u min
Definition: dim.h:289
sparse_matrix< ITYPE, VTYPE > operator*(const STYPE &scalar, const sparse_matrix< ITYPE, VTYPE > &exist_mat)
Definition: matrix_arithmetic.h:753
SCALAR sum(vexpr< E, SCALAR, D > const &u)
Definition: point.h:362
void swap(storage< Args... > &lhs, storage< Args... > &rhs)
Definition: dynstorage.h:234