闲人草堂

  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

项目中需要求解一个动态增长的线性方程组:方程个数会不断增加,因此采用高斯消元法增量求解——每添加完一个方程,就求出当前已经能够确定的变量值。

代码分为三个类:表达式 expression、方程 equation,以及高斯消元求解器 gauss_eliminator。

这段代码仍有问题,如需最新版本,请访问我的 GitHub:https://github.com/tfjiang/gauss_elimination

  1 #ifndef GAUSS_ELIMINATION_H
  2 #define GAUSS_ELIMINATION_H
  3 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <deque>
#include <iostream>
#include <list>
#include <map>
#include <vector>
#include <boost/unordered_map.hpp>
#include <boost/static_assert.hpp>
#include <boost/type_traits/is_same.hpp>
#include <zjucad/matrix/matrix.h>  // a private matrix library
 14 
 15 namespace jtf{
 16 namespace algorithm{
 17 
 18 /**
 19  * @brief store an expression A * i, where A is the coeffieient
 20  *
 21  */
 22 template <typename T>
 23 class expression{
 24 public:
 25   expression():index(-1),coefficient(0){}
 26   expression(const size_t & index_, const T & coefficient_)
 27     : index(index_), coefficient(coefficient_){}
 28   size_t index;
 29   T coefficient;
 30   /**
 31    * @brief operator < is used to determin which expression ha
 32    *
 33    * @param other input other expression
 34    * @return bool
 35    */
 36   bool operator < (const expression<T> & b) const{
 37     return index < b.index;
 38   }
 39 
 40   /**
 41    * @brief To determin whether coefficent of this expression is zeros
 42    *
 43    * @return bool
 44    */
 45   bool is_zero() const{ return fabs(coefficient) < 1e-6;}
 46 };
 47 
 48 /**
 49  * @brief make expression with node_idx and value
 50  *
 51  * @param node_idx
 52  * @param value
 53  * @return expression<T>
 54  */
 55 template <typename T>
 56 expression<T> make_expression(const size_t & node_idx, const T & value)
 57 {
 58   expression<T> temp(node_idx, value);
 59   return temp;
 60 }
 61 
 62 /**
 63  * @brief equation class
 64  *
 65  */
 66 template <typename T>
 67 class equation{
 68 public:
 69   typedef typename std::list<expression<T> >::const_iterator eq_const_iterator;
 70   typedef typename std::list<expression<T> >::iterator eq_iterator;
 71 
 72   equation():value_(static_cast<T>(0)){}
 73 
 74   eq_const_iterator begin() const {return e_vec_.begin();}
 75   eq_iterator begin(){return e_vec_.begin();}
 76 
 77   eq_const_iterator end() const {return e_vec_.end();}
 78   eq_iterator end(){return e_vec_.end();}
 79 
 80   /**
 81    * @brief used to standardizate the equation, sort the expressions and merge
 82    *        similar items and normalization to make the first item's coefficient
 83    *        equal 1
 84    *
 85    * @return int
 86    */
 87   int standardization(){
 88     sort_equation();
 89     merge_similar();
 90     normalization();
 91     return 0;
 92   }
 93 
 94   /**
 95    * @brief update the equation with given node idx and value
 96    *
 97    * @param node_idx input node idx
 98    * @param node_value input node value
 99    * @return int
100    */
101   int update(const size_t & node_idx, const  T & node_value);
102   /**
103    * @brief sort equation accronding to the expressions
104    *
105    * @return int
106    */
107   int sort_equation(){
108     e_vec_.sort();
109     return 0;
110   }
111 
112   /**
113    * @brief merge similar items, WARNING. this function should be used after
114    *        sorting.
115    *
116    * @return int
117    */
118   int merge_similar();
119 
120   /**
121    * @brief normalize the equations: to make the first expression's coefficient
122    *        equals 1
123    *
124    * @return int
125    */
126   int normalization();
127 
128   /**
129    * @brief get the equation value
130    *
131    * @return const T
132    */
133   const T& get_value() const {return value_;}
134 
135   /**
136    * @brief get the state of equation:
137    *            if there are no expressions,
138    *               if value != 0 return -1; error equation
139    *               else return 0; cleared
140    *            else
141    *               if there is one expression return 1; calculated finial node
142    *               else return 2; not finished
143    *
144    * @return int
145    */
146   int state() const;
147 
148 
149   /**
150    * @brief define the minus operation
151    *
152    * @param b
153    * @return equation<T>
154    */
155   equation<T> & operator -= (const equation<T> & b);
156 
157   /**
158      * @brief define the output operation
159      *
160      * @param eq
161      * @return equation<T>
162      */
163   friend std::ostream& operator << (std::ostream &output,
164                              const equation<T> &eq)
165   {
166     if(eq.e_vec_.empty()){
167       output << "# ------ empty expressions with value = " << eq.value() << std::endl;
168     }else{
169       output << "# ------ SUM coeff * index , val" << std::endl
170            << "# ------ ";
171       for(equation<T>::eq_const_iterator eqcit = eq.begin(); eqcit != eq.end();){
172         const expression<T> & exp = *eqcit;
173         output << exp.coefficient << "*X" << exp.index;
174         ++eqcit;
175         if(eqcit == eq.end())
176           output << " = ";
177         else
178           output << " + ";
179       }
180       output << eq.value() << std::endl;
181     }
182     return output;
183   }
184 
185 
186   /**
187    * @brief get the first expression idx
188    *
189    * @return size_t
190    */
191   size_t get_prim_idx() const {
192     return e_vec_.front().index;
193   }
194 
195   int add_expression(const expression<T> & exp){
196     e_vec_.push_back(exp);
197     return 0;
198   }
199   T& value() {return value_;}
200   const T& value() const {return value_;}
201   std::list<expression<T> > e_vec_;
202 private:
203   T value_;
204 };
205 
206 template <typename T>
207 int equation<T>::merge_similar(){
208   typedef typename std::list<expression<T> >::iterator leit;
209   leit current = e_vec_.begin();
210   leit next = current;
211   ++next;
212   while(next != e_vec_.end()){
213     if(next->index > current->index) {
214       current = next++;
215       continue;
216     }else if(next->index == current->index){
217       current->coefficient += next->coefficient;
218       e_vec_.erase(next++);
219       if(fabs(current->coefficient) < 1e-6){
220         e_vec_.erase(current++);
221         ++next;
222       }
223     }else{
224       std::cerr << "# [error] merge similar function should only be called "
225                 << "after sorting." << std::endl;
226       return __LINE__;
227     }
228   }
229   return 0;
230 }
231 
232 template <typename T>
233 int equation<T>::normalization()
234 {
235   const T coeff = e_vec_.front().coefficient;
236   if(fabs(coeff) < 1e-6){
237     if(e_vec_.empty()){
238       //std::cerr << "# [info] this equation is empty." << std::endl;
239       return 0;
240     }else
241       std::cerr << "# [error] this expression should be removed." << std::endl;
242     return __LINE__;
243   }
244 
245   value() /= coeff;
246   for(typename std::list<expression<T> >::iterator lit = e_vec_.begin();
247       lit != e_vec_.end(); ++lit){
248     expression<T> & ep = *lit;
249     ep.coefficient /= coeff;
250   }
251   return 0;
252 }
253 
254 template <typename T>
255 int equation<T>::state() const
256 {
257   if(e_vec_.empty()){
258     if(fabs(value()) < 1e-8)
259       return 0; // is cleared
260     return -1; // is conflicted
261   }else{
262     if(e_vec_.size() == 1)
263       return 1; // finial variant
264     else
265       return 2; // not ready
266   }
267 }
268 
269 template <typename T>
270 equation<T> & equation<T>::operator -= (const equation<T> & b )
271 {
272   if(&b == this) {
273     e_vec_.clear();
274     value() = 0;
275     return *this;
276   }
277 
278   for(typename std::list<expression<T> >::const_iterator lecit_b =
279       b.e_vec_.begin(); lecit_b != b.e_vec_.end(); ++lecit_b){
280     const expression<T> & exp = *lecit_b;
281     const size_t &node_idx = exp.index;
282     assert(fabs(exp.coefficient) > 1e-6);
283     bool is_found = false;
284 
285     for(typename std::list<expression<T> >::iterator leit_a = e_vec_.begin();
286         leit_a != e_vec_.end(); ++leit_a){
287       expression<T> & exp_a = *leit_a;
288       if(exp_a.index == node_idx){
289         exp_a.coefficient -= exp.coefficient;
290         // zeros
291         if(fabs(exp_a.coefficient) < 1e-6) {
292           e_vec_.erase(leit_a);
293           is_found = true;
294           break;
295         }
296       }
297     }
298     if(!is_found){
299       e_vec_.push_back(make_expression(node_idx, -1 * exp.coefficient));
300     }
301   }
302 
303   value() -= b.value();
304   sort_equation();
305   normalization();
306 
307   return *this;
308 }
309 
310 template <typename T>
311 int equation<T>::update(const size_t & node_idx, const  T & node_value)
312 {
313   for(typename std::list<expression<T> >::iterator leit = e_vec_.begin();
314       leit != e_vec_.end();){
315     expression<T> & exp = *leit;
316     if(exp.index == node_idx){
317       value() -= exp.coefficient * node_value;
318       e_vec_.erase(leit++);
319     }
320     ++leit;
321   }
322   return 0;
323 }
324 
325 //! @brief this class only handle Ai+Bi=Ci
326 template <typename T>
327 class gauss_eliminator{
328   BOOST_MPL_ASSERT_MSG((boost::is_same<T,double>::value ) ||
329                        (boost::is_same<T,float>::value ),
330                        NON_FLOAT_TYPES_ARE_NOT_ALLOWED, (void));
331 public:
332   /**
333  * @brief construct gauss_eliminator class
334  *
335  * @param nodes input nodes
336  * @param node_flag input node_flag which will be tagged as true if the
337  *        corresponding node is known
338  */
339   gauss_eliminator(zjucad::matrix::matrix<T> & nodes,
340                    std::vector<bool> & node_flag)
341     :nodes_(nodes), node_flag_(node_flag){
342     idx2equation_.resize(nodes_.size());
343   }
344 
345   /**
346    * @brief add equation to gauss_eliminator, every time an equation is added,
347    *        eliminate function is called.
348    *
349    * @param input equation
350    * @return int
351    */
352   int add_equation(const equation<T> & e);
353 
354   /**
355    * @brief This function will start to eliminate equations above all added equations
356    *
357    * @return int return 0 if works fine, or return non-zeros
358    */
359   int eliminate();
360 
361 
362   /**
363    * @brief update the equation, it will check all variant, if a variant is
364    *        already known, update this equation
365    *
366    * @param eq input equation
367    * @return int return 0 if nothing changes, or retunr 1;
368    */
369   int update_equation(equation<T> & eq);
370 
371 private:
372   zjucad::matrix::matrix<T> & nodes_;
373   std::vector<bool> & node_flag_;
374   std::list<equation<T> > es;
375 
376   typedef typename std::list<equation<T> >::iterator equation_ptr;
377   std::vector<std::list<equation_ptr> > idx2equation_;
378 
379   // this map store the smallest expression
380   typedef typename std::map<size_t, std::list<equation_ptr> >::iterator prime_eq_ptr;
381   std::map<size_t, std::list<equation_ptr> > prime_idx2equation_;
382 };
383 
384 template <typename T>
385 int gauss_eliminator<T>::add_equation(const equation<T> & e){
386   es.push_back(e);
387   equation<T> & e_back = es.back();
388   for(typename equation<T>::eq_iterator eit = e_back.begin(); eit != e_back.end(); ){
389     if(node_flag_[eit->index]){
390       e_back.value() -= nodes_[eit->index] * eit->coefficient;
391       e_back.e_vec_.erase(eit++);
392     }else
393       ++eit;
394   }
395   if(e_back.state() == 0){// this equation is cleared
396     es.pop_back();
397     return 0;
398   }else if(e_back.state() == -1){
399     std::cerr << "# [error] strange conflict equation: " << std::endl;
400     std::cerr << e;
401     es.pop_back();
402     return __LINE__;
403   }
404   e_back.standardization();
405 
406   for(typename equation<T>::eq_const_iterator it = e.begin();
407       it != e.end(); ++it){
408     equation_ptr end_ptr = es.end();
409     idx2equation_[it->index].push_back(--end_ptr);
410   }
411 
412   equation_ptr end_ptr = es.end();
413   prime_idx2equation_[e_back.get_prim_idx()].push_back(--end_ptr);
414   eliminate();
415   return 0;
416 }
417 
418 template <typename T>
419 int gauss_eliminator<T>::update_equation(equation<T> & eq)
420 {
421   for(typename std::list<expression<T> >::iterator it = eq.e_vec_.begin();
422       it != eq.e_vec_.end(); ){
423     const expression<T> & exp = *it;
424     if(node_flag_[exp.index]){
425       eq.value() -= exp.coefficient * nodes_[exp.index];
426       eq.e_vec_.erase(it++);
427     }else
428       ++it;
429   }
430   eq.standardization();
431   return 0;
432 }
433 
434 template <typename T>
435 int gauss_eliminator<T>::eliminate()
436 {
437   std::cerr << std::endl;
438   while(1){
439     bool is_modified = false;
440 
441     for(prime_eq_ptr ptr = prime_idx2equation_.begin();
442         ptr != prime_idx2equation_.end();)
443     {
444       std::list<equation_ptr> & dle = ptr->second;
445       if(dle.empty()) {
446         prime_idx2equation_.erase(ptr++);
447         is_modified = true;
448         continue;
449       }else if(dle.size() == 1){ // contain only one equation
450         const int state_ = dle.front()->state();
451         if(state_ == 0) { // cleared
452           es.erase(dle.front());
453           continue;
454         }else if(state_ == -1){ // conflict equation
455           std::cerr << "# [error] conflict equation " << std::endl;
456           return __LINE__;
457         }else if(state_ == 1){ // finial variant
458           const equation<T> & eq = *dle.front();
459           const T &value_ = eq.value();
460           // prime index's coefficient should be 1
461           assert(fabs(eq.e_vec_.front().coefficient - 1) < 1e-6);
462           const size_t index = eq.get_prim_idx();
463           if(node_flag_[index]){
464             if(fabs(nodes_[index] - value_) > 1e-6){
465               std::cerr << "# [error] conficts happen, node " << index
466                         << " has different value " << nodes_[index] << ","
467                         << value_ << std::endl;
468               return __LINE__;
469             }
470           }else{
471             // update corresponding equations with the new node value
472             nodes_[index] = value_;
473             node_flag_[index] = true;
474             std::list<equation_ptr> & node_linked_eq = idx2equation_[index];
475             for(typename std::list<equation_ptr>::iterator leqit =
476                 node_linked_eq.begin(); leqit != node_linked_eq.end(); ++leqit){
477               equation<T> & eq = *(*leqit);
478               eq.update(index, nodes_[index]);
479               eq.standardization();
480             }
481             node_linked_eq.clear();
482             prime_idx2equation_.erase(ptr);
483             is_modified = true;
484           }
485         }
486         ++ptr;
487       }else{
488         assert(dle.size() > 1);
489         // this prime_index point to several equations,
490         // which sould be eliminated
491         typename std::list<equation_ptr>::iterator begin = dle.begin();
492         typename std::list<equation_ptr>::iterator first = begin++;
493         // to keep each prime index linked only one equation
494         for(typename std::list<equation_ptr>::iterator next = begin;
495             next != dle.end();){
496           // to eliminate the prime index, each equation minus the first one
497           *(*next) -= *(*first);
498           (*next)->standardization();
499           const size_t prim_index = (*next)->get_prim_idx();
500           assert(prim_index >= (*first)->get_prim_idx());
501           prime_idx2equation_[prim_index].push_back(*next);
502           dle.erase(next++);
503         }
504         ++ptr;
505       } // end else
506     }
507     if(!is_modified) break;
508   }
509   return 0;
510 }
511 }
512 }
513 #endif // GAUSS_ELIMINATION_H

 

下面有个简单的测试例子:

//2x2+2x1+2x1=2
//x1=1
//5x1+x2=2
 1  int main()
 2 { typedef double val_type;
 3   matrix<val_type> node = zeros<val_type>(4,1);
 4   vector<bool> node_flag(4,false);
 5 
 6   jtf::algorithm::gauss_eliminator<val_type> ge(node, node_flag);
 7 
 8   {
 9     equation<val_type> eq;
10     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(1),
11                                                       static_cast<val_type>(2)));
12     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0),
13                                                       static_cast<val_type>(2)));
14     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0),
15                                                       static_cast<val_type>(2)));
16     eq.value() = 2;
17     ge.add_equation(eq);
18   }
19   {
20     equation<val_type> eq;
21     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0),
22                                                       static_cast<val_type>(1)));
23     eq.value() = 1;
24     ge.add_equation(eq);
25   }
26   {
27     equation<val_type> eq;
28     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0),
29                                                       static_cast<val_type>(5)));
30 
31     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(1),
32                                                       static_cast<val_type>(1)));
33 
34     eq.value() = 2;
35     ge.add_equation(eq);
36   }
37   for(size_t t = 0; t < node_flag.size(); ++t){
38     if(node_flag[t] == true){
39       cerr << "# node " << t << " = " << node[t] << endl;
40     }else
41       cerr << "# node " << t << " unknown." << endl;
42   }
43 return 0;
44 }

这段code在g++4.6.3上编译通过,如果有更加方便简单的方法,请大家指教~

posted on 2012-09-05 10:09  闲人草堂  阅读(1996)  评论(0编辑  收藏  举报