50 #ifndef MUI_CONJUGATE_GRADIENT_H_
51 #define MUI_CONJUGATE_GRADIENT_H_
59 template<
typename ITYPE,
typename VTYPE>
63 cg_solve_tol_(cg_solve_tol),
64 cg_max_iter_(cg_max_iter),
66 assert(b_.get_cols() == 1 &&
67 "MUI Error [solver_cg.h]: Number of column of b matrix must be 1");
68 x_.resize(A_.get_rows(),1);
69 r_.resize(A_.get_rows(),1);
70 z_.resize(A_.get_rows(),1);
71 p_.resize(A_.get_rows(),1);
75 template<
typename ITYPE,
typename VTYPE>
79 cg_solve_tol_(cg_solve_tol),
80 cg_max_iter_(cg_max_iter),
82 assert(A_.get_rows() == b_.get_rows() &&
83 "MUI Error [solver_cg.h]: Number of rows of A matrix must be the same as the number of rows of b matrix");
84 b_column_.resize(b_.get_rows(),1);
85 x_.resize(b_.get_rows(),b_.get_cols());
86 x_init_column_.resize(b_.get_rows(),1);
90 template<
typename ITYPE,
typename VTYPE>
110 template<
typename ITYPE,
typename VTYPE>
116 b_column_.set_zero();
117 x_init_column_.set_zero();
129 template<
typename ITYPE,
typename VTYPE>
131 if (!x_init.
empty()){
132 assert(((x_init.
get_rows() == x_.get_rows()) && (x_init.
get_cols() == x_.get_cols())) &&
133 "MUI Error [solver_cg.h]: Size of x_init matrix mismatch with size of x_ matrix");
149 tempZ = M_->apply(z_);
157 VTYPE r_norm0 = r_.dot_product(z_);
159 "MUI Error [solver_cg.h]: Divide by zero assert for r_norm0");
160 VTYPE r_norm = r_norm0;
161 VTYPE r_norm_rel = std::sqrt(r_norm/r_norm0);
164 if(cg_max_iter_ == 0) {
167 kIter = cg_max_iter_;
170 ITYPE acturalKIterCount = 0;
172 for (ITYPE k = 0; k < kIter; ++k) {
177 "MUI Error [solver_cg.h]: Divide by zero assert for p_dot_Ap");
178 VTYPE alpha = r_norm / p_dot_Ap;
179 for (ITYPE j = 0; j < A_.get_rows(); ++j) {
180 x_.add_scalar(j, 0, (alpha * (p_.get_value(j,0))));
181 r_.subtract_scalar(j, 0, (alpha * (Ap.
get_value(j,0))));
189 tempZ = M_->apply(z_);
194 VTYPE updated_r_norm = r_.dot_product(z_);
196 "MUI Error [solver_cg.h]: Divide by zero assert for r_norm");
197 VTYPE beta = updated_r_norm / r_norm;
198 r_norm = updated_r_norm;
199 for (ITYPE j = 0; j < A_.get_rows(); ++j) {
200 p_.set_value(j, 0, (z_.get_value(j,0)+(beta*p_.get_value(j,0))));
203 r_norm_rel = std::sqrt(r_norm/r_norm0);
204 if (r_norm_rel <= cg_solve_tol_) {
208 return std::make_pair(acturalKIterCount,r_norm_rel);
212 template<
typename ITYPE,
typename VTYPE>
214 if (!x_init.
empty()){
215 assert(((x_init.
get_rows() == b_.get_rows()) && (x_init.
get_cols() == b_.get_cols())) &&
216 "MUI Error [solver_cg.h]: Size of x_init matrix mismatch with size of b_ matrix");
219 std::pair<ITYPE, VTYPE> cgReturn;
220 for (ITYPE j = 0; j < b_.get_cols(); ++j) {
221 b_column_.set_zero();
222 b_column_ = b_.segment(0,(b_.get_rows()-1),j,j);
224 if (!x_init.
empty()) {
225 x_init_column_.set_zero();
228 std::pair<ITYPE, VTYPE> cgReturnTemp = cg.
solve(x_init_column_);
229 if (cgReturn.first < cgReturnTemp.first)
230 cgReturn.first = cgReturnTemp.first;
231 cgReturn.second += cgReturnTemp.second;
234 for (ITYPE i = 0; i < x_column.
get_rows(); ++i) {
235 x_.set_value(i, j, x_column.
get_value(i,0));
238 cgReturn.second /= b_.get_cols();
244 template<
typename ITYPE,
typename VTYPE>
250 template<
typename ITYPE,
typename VTYPE>
std::pair< ITYPE, VTYPE > solve(sparse_matrix< ITYPE, VTYPE >=sparse_matrix< ITYPE, VTYPE >())
Definition: solver_cg.h:130
sparse_matrix< ITYPE, VTYPE > getSolution()
Definition: solver_cg.h:245
~conjugate_gradient_1d()
Definition: solver_cg.h:91
conjugate_gradient_1d(sparse_matrix< ITYPE, VTYPE >, sparse_matrix< ITYPE, VTYPE >, VTYPE=1e-6, ITYPE=0, preconditioner< ITYPE, VTYPE > *=nullptr)
Definition: solver_cg.h:60
sparse_matrix< ITYPE, VTYPE > getSolution()
Definition: solver_cg.h:251
std::pair< ITYPE, VTYPE > solve(sparse_matrix< ITYPE, VTYPE >=sparse_matrix< ITYPE, VTYPE >())
Definition: solver_cg.h:213
conjugate_gradient(sparse_matrix< ITYPE, VTYPE >, sparse_matrix< ITYPE, VTYPE >, VTYPE=1e-6, ITYPE=0, preconditioner< ITYPE, VTYPE > *=nullptr)
Definition: solver_cg.h:76
~conjugate_gradient()
Definition: solver_cg.h:111
Definition: preconditioner.h:55
ITYPE get_rows() const
Definition: matrix_io_info.h:579
VTYPE dot_product(sparse_matrix< ITYPE, VTYPE > &) const
Definition: matrix_arithmetic.h:759
void set_zero()
Definition: matrix_manipulation.h:418
VTYPE get_value(ITYPE, ITYPE) const
Definition: matrix_io_info.h:523
sparse_matrix< ITYPE, VTYPE > segment(ITYPE, ITYPE, ITYPE, ITYPE, bool=true)
Definition: matrix_manipulation.h:153
ITYPE get_cols() const
Definition: matrix_io_info.h:585
void copy(const sparse_matrix< ITYPE, VTYPE > &)
Definition: matrix_manipulation.h:79
bool empty() const
Definition: matrix_io_info.h:665
u u u u u u min
Definition: dim.h:289
SCALAR max(vexpr< E, SCALAR, D > const &u)
Definition: point.h:350