项目中需要求解动态增长的线性方程组:方程个数会不断增加,需要使用高斯消元法来求解,并且每添加一个方程后,都要求出当前已经能够确定的变量值。
代码分为三个类:表达式 expression、方程 equation 和高斯消元求解器 gauss_eliminator。
这段代码仍然存在问题,如需最新版本,请访问我的 GitHub:https://github.com/tfjiang/gauss_elimination
1 #ifndef GAUSS_ELIMINATION_H
2 #define GAUSS_ELIMINATION_H
3
#include <cassert>
#include <cmath>
#include <deque>
#include <iostream>
#include <list>
#include <map>
#include <vector>
#include <boost/mpl/assert.hpp>
#include <boost/static_assert.hpp>
#include <boost/type_traits/is_same.hpp>
#include <boost/unordered_map.hpp>
#include <zjucad/matrix/matrix.h> // a private matrix library
14
15 namespace jtf{
16 namespace algorithm{
17
/**
 * @brief A single term  coefficient * X[index]  of a linear equation.
 *
 * @tparam T floating-point coefficient type
 */
template <typename T>
class expression{
public:
  /// Default term: index is the largest size_t value, coefficient is 0.
  expression():index(static_cast<size_t>(-1)),coefficient(0){}
  expression(const size_t & index_, const T & coefficient_)
    : index(index_), coefficient(coefficient_){}

  size_t index;     ///< variable index
  T coefficient;    ///< coefficient multiplying that variable

  /**
   * @brief Order terms by variable index (used when sorting an equation).
   *
   * @param b other term
   * @return bool true if this term's index is smaller than b's
   */
  bool operator < (const expression<T> & b) const{
    return index < b.index;
  }

  /**
   * @brief Whether the coefficient is numerically zero, i.e. |coefficient| < eps.
   *
   * @param eps tolerance; defaults to 1e-6
   * @return bool
   */
  bool is_zero(const T & eps = static_cast<T>(1e-6)) const{
    // std::fabs (not ::fabs) so that float coefficients use the float overload.
    return std::fabs(coefficient) < eps;
  }
};
47
48 /**
49 * @brief make expression with node_idx and value
50 *
51 * @param node_idx
52 * @param value
53 * @return expression<T>
54 */
55 template <typename T>
56 expression<T> make_expression(const size_t & node_idx, const T & value)
57 {
58 expression<T> temp(node_idx, value);
59 return temp;
60 }
61
62 /**
63 * @brief equation class
64 *
65 */
66 template <typename T>
67 class equation{
68 public:
69 typedef typename std::list<expression<T> >::const_iterator eq_const_iterator;
70 typedef typename std::list<expression<T> >::iterator eq_iterator;
71
72 equation():value_(static_cast<T>(0)){}
73
74 eq_const_iterator begin() const {return e_vec_.begin();}
75 eq_iterator begin(){return e_vec_.begin();}
76
77 eq_const_iterator end() const {return e_vec_.end();}
78 eq_iterator end(){return e_vec_.end();}
79
80 /**
81 * @brief used to standardizate the equation, sort the expressions and merge
82 * similar items and normalization to make the first item's coefficient
83 * equal 1
84 *
85 * @return int
86 */
87 int standardization(){
88 sort_equation();
89 merge_similar();
90 normalization();
91 return 0;
92 }
93
94 /**
95 * @brief update the equation with given node idx and value
96 *
97 * @param node_idx input node idx
98 * @param node_value input node value
99 * @return int
100 */
101 int update(const size_t & node_idx, const T & node_value);
102 /**
103 * @brief sort equation accronding to the expressions
104 *
105 * @return int
106 */
107 int sort_equation(){
108 e_vec_.sort();
109 return 0;
110 }
111
112 /**
113 * @brief merge similar items, WARNING. this function should be used after
114 * sorting.
115 *
116 * @return int
117 */
118 int merge_similar();
119
120 /**
121 * @brief normalize the equations: to make the first expression's coefficient
122 * equals 1
123 *
124 * @return int
125 */
126 int normalization();
127
128 /**
129 * @brief get the equation value
130 *
131 * @return const T
132 */
133 const T& get_value() const {return value_;}
134
135 /**
136 * @brief get the state of equation:
137 * if there are no expressions,
138 * if value != 0 return -1; error equation
139 * else return 0; cleared
140 * else
141 * if there is one expression return 1; calculated finial node
142 * else return 2; not finished
143 *
144 * @return int
145 */
146 int state() const;
147
148
149 /**
150 * @brief define the minus operation
151 *
152 * @param b
153 * @return equation<T>
154 */
155 equation<T> & operator -= (const equation<T> & b);
156
157 /**
158 * @brief define the output operation
159 *
160 * @param eq
161 * @return equation<T>
162 */
163 friend std::ostream& operator << (std::ostream &output,
164 const equation<T> &eq)
165 {
166 if(eq.e_vec_.empty()){
167 output << "# ------ empty expressions with value = " << eq.value() << std::endl;
168 }else{
169 output << "# ------ SUM coeff * index , val" << std::endl
170 << "# ------ ";
171 for(equation<T>::eq_const_iterator eqcit = eq.begin(); eqcit != eq.end();){
172 const expression<T> & exp = *eqcit;
173 output << exp.coefficient << "*X" << exp.index;
174 ++eqcit;
175 if(eqcit == eq.end())
176 output << " = ";
177 else
178 output << " + ";
179 }
180 output << eq.value() << std::endl;
181 }
182 return output;
183 }
184
185
186 /**
187 * @brief get the first expression idx
188 *
189 * @return size_t
190 */
191 size_t get_prim_idx() const {
192 return e_vec_.front().index;
193 }
194
195 int add_expression(const expression<T> & exp){
196 e_vec_.push_back(exp);
197 return 0;
198 }
199 T& value() {return value_;}
200 const T& value() const {return value_;}
201 std::list<expression<T> > e_vec_;
202 private:
203 T value_;
204 };
205
206 template <typename T>
207 int equation<T>::merge_similar(){
208 typedef typename std::list<expression<T> >::iterator leit;
209 leit current = e_vec_.begin();
210 leit next = current;
211 ++next;
212 while(next != e_vec_.end()){
213 if(next->index > current->index) {
214 current = next++;
215 continue;
216 }else if(next->index == current->index){
217 current->coefficient += next->coefficient;
218 e_vec_.erase(next++);
219 if(fabs(current->coefficient) < 1e-6){
220 e_vec_.erase(current++);
221 ++next;
222 }
223 }else{
224 std::cerr << "# [error] merge similar function should only be called "
225 << "after sorting." << std::endl;
226 return __LINE__;
227 }
228 }
229 return 0;
230 }
231
232 template <typename T>
233 int equation<T>::normalization()
234 {
235 const T coeff = e_vec_.front().coefficient;
236 if(fabs(coeff) < 1e-6){
237 if(e_vec_.empty()){
238 //std::cerr << "# [info] this equation is empty." << std::endl;
239 return 0;
240 }else
241 std::cerr << "# [error] this expression should be removed." << std::endl;
242 return __LINE__;
243 }
244
245 value() /= coeff;
246 for(typename std::list<expression<T> >::iterator lit = e_vec_.begin();
247 lit != e_vec_.end(); ++lit){
248 expression<T> & ep = *lit;
249 ep.coefficient /= coeff;
250 }
251 return 0;
252 }
253
254 template <typename T>
255 int equation<T>::state() const
256 {
257 if(e_vec_.empty()){
258 if(fabs(value()) < 1e-8)
259 return 0; // is cleared
260 return -1; // is conflicted
261 }else{
262 if(e_vec_.size() == 1)
263 return 1; // finial variant
264 else
265 return 2; // not ready
266 }
267 }
268
269 template <typename T>
270 equation<T> & equation<T>::operator -= (const equation<T> & b )
271 {
272 if(&b == this) {
273 e_vec_.clear();
274 value() = 0;
275 return *this;
276 }
277
278 for(typename std::list<expression<T> >::const_iterator lecit_b =
279 b.e_vec_.begin(); lecit_b != b.e_vec_.end(); ++lecit_b){
280 const expression<T> & exp = *lecit_b;
281 const size_t &node_idx = exp.index;
282 assert(fabs(exp.coefficient) > 1e-6);
283 bool is_found = false;
284
285 for(typename std::list<expression<T> >::iterator leit_a = e_vec_.begin();
286 leit_a != e_vec_.end(); ++leit_a){
287 expression<T> & exp_a = *leit_a;
288 if(exp_a.index == node_idx){
289 exp_a.coefficient -= exp.coefficient;
290 // zeros
291 if(fabs(exp_a.coefficient) < 1e-6) {
292 e_vec_.erase(leit_a);
293 is_found = true;
294 break;
295 }
296 }
297 }
298 if(!is_found){
299 e_vec_.push_back(make_expression(node_idx, -1 * exp.coefficient));
300 }
301 }
302
303 value() -= b.value();
304 sort_equation();
305 normalization();
306
307 return *this;
308 }
309
310 template <typename T>
311 int equation<T>::update(const size_t & node_idx, const T & node_value)
312 {
313 for(typename std::list<expression<T> >::iterator leit = e_vec_.begin();
314 leit != e_vec_.end();){
315 expression<T> & exp = *leit;
316 if(exp.index == node_idx){
317 value() -= exp.coefficient * node_value;
318 e_vec_.erase(leit++);
319 }
320 ++leit;
321 }
322 return 0;
323 }
324
325 //! @brief this class only handle Ai+Bi=Ci
326 template <typename T>
327 class gauss_eliminator{
328 BOOST_MPL_ASSERT_MSG((boost::is_same<T,double>::value ) ||
329 (boost::is_same<T,float>::value ),
330 NON_FLOAT_TYPES_ARE_NOT_ALLOWED, (void));
331 public:
332 /**
333 * @brief construct gauss_eliminator class
334 *
335 * @param nodes input nodes
336 * @param node_flag input node_flag which will be tagged as true if the
337 * corresponding node is known
338 */
339 gauss_eliminator(zjucad::matrix::matrix<T> & nodes,
340 std::vector<bool> & node_flag)
341 :nodes_(nodes), node_flag_(node_flag){
342 idx2equation_.resize(nodes_.size());
343 }
344
345 /**
346 * @brief add equation to gauss_eliminator, every time an equation is added,
347 * eliminate function is called.
348 *
349 * @param input equation
350 * @return int
351 */
352 int add_equation(const equation<T> & e);
353
354 /**
355 * @brief This function will start to eliminate equations above all added equations
356 *
357 * @return int return 0 if works fine, or return non-zeros
358 */
359 int eliminate();
360
361
362 /**
363 * @brief update the equation, it will check all variant, if a variant is
364 * already known, update this equation
365 *
366 * @param eq input equation
367 * @return int return 0 if nothing changes, or retunr 1;
368 */
369 int update_equation(equation<T> & eq);
370
371 private:
372 zjucad::matrix::matrix<T> & nodes_;
373 std::vector<bool> & node_flag_;
374 std::list<equation<T> > es;
375
376 typedef typename std::list<equation<T> >::iterator equation_ptr;
377 std::vector<std::list<equation_ptr> > idx2equation_;
378
379 // this map store the smallest expression
380 typedef typename std::map<size_t, std::list<equation_ptr> >::iterator prime_eq_ptr;
381 std::map<size_t, std::list<equation_ptr> > prime_idx2equation_;
382 };
383
384 template <typename T>
385 int gauss_eliminator<T>::add_equation(const equation<T> & e){
386 es.push_back(e);
387 equation<T> & e_back = es.back();
388 for(typename equation<T>::eq_iterator eit = e_back.begin(); eit != e_back.end(); ){
389 if(node_flag_[eit->index]){
390 e_back.value() -= nodes_[eit->index] * eit->coefficient;
391 e_back.e_vec_.erase(eit++);
392 }else
393 ++eit;
394 }
395 if(e_back.state() == 0){// this equation is cleared
396 es.pop_back();
397 return 0;
398 }else if(e_back.state() == -1){
399 std::cerr << "# [error] strange conflict equation: " << std::endl;
400 std::cerr << e;
401 es.pop_back();
402 return __LINE__;
403 }
404 e_back.standardization();
405
406 for(typename equation<T>::eq_const_iterator it = e.begin();
407 it != e.end(); ++it){
408 equation_ptr end_ptr = es.end();
409 idx2equation_[it->index].push_back(--end_ptr);
410 }
411
412 equation_ptr end_ptr = es.end();
413 prime_idx2equation_[e_back.get_prim_idx()].push_back(--end_ptr);
414 eliminate();
415 return 0;
416 }
417
418 template <typename T>
419 int gauss_eliminator<T>::update_equation(equation<T> & eq)
420 {
421 for(typename std::list<expression<T> >::iterator it = eq.e_vec_.begin();
422 it != eq.e_vec_.end(); ){
423 const expression<T> & exp = *it;
424 if(node_flag_[exp.index]){
425 eq.value() -= exp.coefficient * nodes_[exp.index];
426 eq.e_vec_.erase(it++);
427 }else
428 ++it;
429 }
430 eq.standardization();
431 return 0;
432 }
433
434 template <typename T>
435 int gauss_eliminator<T>::eliminate()
436 {
437 std::cerr << std::endl;
438 while(1){
439 bool is_modified = false;
440
441 for(prime_eq_ptr ptr = prime_idx2equation_.begin();
442 ptr != prime_idx2equation_.end();)
443 {
444 std::list<equation_ptr> & dle = ptr->second;
445 if(dle.empty()) {
446 prime_idx2equation_.erase(ptr++);
447 is_modified = true;
448 continue;
449 }else if(dle.size() == 1){ // contain only one equation
450 const int state_ = dle.front()->state();
451 if(state_ == 0) { // cleared
452 es.erase(dle.front());
453 continue;
454 }else if(state_ == -1){ // conflict equation
455 std::cerr << "# [error] conflict equation " << std::endl;
456 return __LINE__;
457 }else if(state_ == 1){ // finial variant
458 const equation<T> & eq = *dle.front();
459 const T &value_ = eq.value();
460 // prime index's coefficient should be 1
461 assert(fabs(eq.e_vec_.front().coefficient - 1) < 1e-6);
462 const size_t index = eq.get_prim_idx();
463 if(node_flag_[index]){
464 if(fabs(nodes_[index] - value_) > 1e-6){
465 std::cerr << "# [error] conficts happen, node " << index
466 << " has different value " << nodes_[index] << ","
467 << value_ << std::endl;
468 return __LINE__;
469 }
470 }else{
471 // update corresponding equations with the new node value
472 nodes_[index] = value_;
473 node_flag_[index] = true;
474 std::list<equation_ptr> & node_linked_eq = idx2equation_[index];
475 for(typename std::list<equation_ptr>::iterator leqit =
476 node_linked_eq.begin(); leqit != node_linked_eq.end(); ++leqit){
477 equation<T> & eq = *(*leqit);
478 eq.update(index, nodes_[index]);
479 eq.standardization();
480 }
481 node_linked_eq.clear();
482 prime_idx2equation_.erase(ptr);
483 is_modified = true;
484 }
485 }
486 ++ptr;
487 }else{
488 assert(dle.size() > 1);
489 // this prime_index point to several equations,
490 // which sould be eliminated
491 typename std::list<equation_ptr>::iterator begin = dle.begin();
492 typename std::list<equation_ptr>::iterator first = begin++;
493 // to keep each prime index linked only one equation
494 for(typename std::list<equation_ptr>::iterator next = begin;
495 next != dle.end();){
496 // to eliminate the prime index, each equation minus the first one
497 *(*next) -= *(*first);
498 (*next)->standardization();
499 const size_t prim_index = (*next)->get_prim_idx();
500 assert(prim_index >= (*first)->get_prim_idx());
501 prime_idx2equation_[prim_index].push_back(*next);
502 dle.erase(next++);
503 }
504 ++ptr;
505 } // end else
506 }
507 if(!is_modified) break;
508 }
509 return 0;
510 }
511 }
512 }
513 #endif // GAUSS_ELIMINATION_H
下面有个简单的测试例子:
//2x2+2x1+2x1=2
//x1=1
//5x1+x2=2
1 int main()
2 { typedef double val_type;
3 matrix<val_type> node = zeros<val_type>(4,1);
4 vector<bool> node_flag(4,false);
5
6 jtf::algorithm::gauss_eliminator<val_type> ge(node, node_flag);
7
8 {
9 equation<val_type> eq;
10 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(1),
11 static_cast<val_type>(2)));
12 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0),
13 static_cast<val_type>(2)));
14 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0),
15 static_cast<val_type>(2)));
16 eq.value() = 2;
17 ge.add_equation(eq);
18 }
19 {
20 equation<val_type> eq;
21 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0),
22 static_cast<val_type>(1)));
23 eq.value() = 1;
24 ge.add_equation(eq);
25 }
26 {
27 equation<val_type> eq;
28 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0),
29 static_cast<val_type>(5)));
30
31 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(1),
32 static_cast<val_type>(1)));
33
34 eq.value() = 2;
35 ge.add_equation(eq);
36 }
37 for(size_t t = 0; t < node_flag.size(); ++t){
38 if(node_flag[t] == true){
39 cerr << "# node " << t << " = " << node[t] << endl;
40 }else
41 cerr << "# node " << t << " unknown." << endl;
42 }
43 return 0;
44 }
这段代码在 g++ 4.6.3 上编译通过(依赖 Boost 以及私有的 zjucad matrix 库)。如果有更加方便简单的方法,请大家指教~