数据结构:稀疏矩阵的简单运算

最近专业课比较多,没怎么看数据结构了。今天有点时间写了一下数据结构中稀疏矩阵的运算。写过之后感觉代码的构思不太好。这里先发一下,

假期的时候再重新实现一下。

SpMatrix.h 头文件

  1 //The sparse matrix's header file
  2 //the class  SpMatrix
  3 //Author: 'Karllen'
  4 //Date:05/28/2014
  5 
  6 #ifndef _SpMatrix_H_
  7 #define _SPMatrix_H_
  8 #include <iostream>
  9 template<class T>
 10 class Triple
 11 {
 12 public:
 13     T ele;      // The Triple  element
 14     int i,j;    // The element location in the matrix
 15 };
 16 
 17 template<class T>
 18 class SpMatrix
 19 {
 20 public:
 21     SpMatrix(void);
 22     SpMatrix(T **inptr,int mrow,int mcolumn);                        //get the matrix from inptr
 23     ~SpMatrix(void);
 24     void transposeSpMatrix(SpMatrix& osp);                           //transpose matrix
 25     void addSpMatrix(const SpMatrix& lhs,const SpMatrix& rhs);       //add two sparse matrix
 26     void subSpMatrix(const SpMatrix& lhs,const SpMatrix& rhs);       //sub two sparse matrix
 27     void mulSpMatrix(const SpMatrix& lhs,const SpMatrix& rhs);       //multiplication two sparse matrix
 28     void divSpMatrix(const SpMatrix& lhs,const SpMatrix& rhs);       //divide two sparse matrix
 29     void sptoprintMatxi()const;                                      //print the sparse matrix as matrix
 30     void printMatrix()const;                                         //print the sparse matrix 
 31 private:
 32     Triple<T> *tri;                                                  //the pointer to sparse matrix
 33     int   row,column;                                                //the sparse matrix's row and column 
 34     int   count;                                                     //the total numbers of triple
 35 };
 36 
 37 template<class T>
 38 SpMatrix<T>::SpMatrix(void):row(0),column(0),count(0),tri(NULL)
 39 {
 40 
 41 }
 42 
 43 template<class T>
 44 SpMatrix<T>::SpMatrix(T **inptr,int mrow,int mcolumn)
 45 {
 46     count  = 0;
 47     row    = mrow;
 48     column = mcolumn;
 49     tri    = new Triple<T>[mrow*mcolumn]; 
 50     for (int i = 0;i<row;++i)
 51     {
 52         for (int j = 0;j<column;++j)
 53         {
 54             if (0 !=inptr[i][j])
 55             {
 56                 tri[count].i = i;
 57                 tri[count].j = j;
 58                 tri[count].ele = inptr[i][j];
 59                 ++count;
 60             }
 61         }
 62     }  
 63 }
 64 
 65 template<class T>
 66 SpMatrix<T>::~SpMatrix(void)
 67 {
 68     if (tri!=NULL)
 69     {
 70         delete []tri;
 71     }
 72 }
 73 
 74 template<class T>
 75 void SpMatrix<T>::transposeSpMatrix(SpMatrix& osp)
 76 {
 77     if (0!=osp.count)
 78     {
 79         tri = new Triple<T>[osp.count];
 80         row    = osp.column;
 81         column = osp.row;
 82         int k = 0;
 83         for (int i = 0;i<osp.column;++i)
 84         {
 85             for (int j = 0;j<osp.count;++j)
 86             {
 87                 if (i == osp.tri[j].j)
 88                 {
 89                     tri[k].i   = osp.tri[j].j;
 90                     tri[k].j   = osp.tri[j].i;
 91                     tri[k].ele = osp.tri[j].ele;
 92                     ++k;
 93                     ++count;
 94                 }
 95             }
 96         }
 97     }
 98     else
 99     {
100         std::cout<<"三元组个数为0,不能转置"<<std::endl;
101     }
102 }
103 
104 template<class T>
105 void SpMatrix<T>::addSpMatrix(const SpMatrix& lhs,const SpMatrix& rhs)
106 {
107     if (lhs.row==lhs.column&&rhs.row==rhs.column&&lhs.row==rhs.row)
108     {
109         row    = lhs.row;
110         column = lhs.column;
111         tri    = new Triple<T>[lhs.count+rhs.count];
112         int li=0,lj=0,i=0;
113         while(li<lhs.count&&lj<rhs.count)
114         {
115             if (lhs.tri[li].i==rhs.tri[lj].i)
116             {
117                 if (lhs.tri[li].j==rhs.tri[lj].j)
118                 {
119                     if (lhs.tri[li].ele+rhs.tri[lj].ele!=0)
120                     {
121                         tri[i].i   = lhs.tri[li].i;
122                         tri[i].j   = lhs.tri[li].j;
123                         tri[i].ele = lhs.tri[li].ele+rhs.tri[lj].ele;
124                         ++i;
125                         ++count;
126                     }    
127                     ++li;
128                     ++lj;
129                 }
130                 else
131                 {
132                     if (lhs.tri[li].j<rhs.tri[lj].j)
133                     {
134                         tri[i].ele = lhs.tri[li].ele;
135                         tri[i].i   = lhs.tri[li].i;
136                         tri[i].j   = lhs.tri[li].j;
137                         ++i;
138                         ++count;
139                         ++li;
140                     }
141                     else
142                     {
143                         tri[i].ele = rhs.tri[lj].ele;
144                         tri[i].i   = rhs.tri[lj].i;
145                         tri[i].j   = rhs.tri[lj].j;
146                         ++i;
147                         ++count;
148                         ++lj;
149                     }
150                 }
151             }
152             else
153             {
154                 if (lhs.tri[li].i<rhs.tri[lj].i)
155                 {
156                     tri[i].ele = lhs.tri[li].ele;
157                     tri[i].i   = lhs.tri[li].i;
158                     tri[i].j   = lhs.tri[li].j;
159                     ++i;
160                     ++count;
161                     ++li;
162                 }
163                 else
164                 {
165                     tri[i].ele = rhs.tri[lj].ele;
166                     tri[i].i   = rhs.tri[lj].i;
167                     tri[i].j   = rhs.tri[lj].j;
168                     ++i;
169                     ++count;
170                     ++lj;
171                 }
172             }
173 
174         }
175         if (li==lhs.count)
176         {
177             while (lj<rhs.count)
178             {
179                 tri[i].ele = rhs.tri[lj].ele;
180                 tri[i].i   = rhs.tri[lj].i;
181                 tri[i].j   = rhs.tri[lj].j;
182                 ++i;
183                 ++count;
184                 ++lj;
185             }
186         }
187         if(lj==rhs.count&&lhs.count!=li)
188         {
189             while (li<lhs.count)
190             {
191                 tri[i].ele = lhs.tri[li].ele;
192                 tri[i].i   = lhs.tri[li].i;
193                 tri[i].j   = lhs.tri[li].j;
194                 ++i;
195                 ++count;
196                 ++li;
197             }
198         }
199     }
200     else
201     {
202         std::cout<<"Two matrix is different type and aren't  to  add!"<<std::endl;
203     }
204 }
205 
206 template<class T>
207 void SpMatrix<T>::subSpMatrix(const SpMatrix& lhs,const SpMatrix& rhs)
208 {
209     if (lhs.row==lhs.column&&rhs.row==rhs.column&&lhs.row==rhs.row)
210     {
211         row    = lhs.row;
212         column = lhs.column;
213         tri    = new Triple<T>[lhs.count+rhs.count];
214         int li=0,lj=0,i=0;
215         while(li<lhs.count&&lj<rhs.count)
216         {
217             if (lhs.tri[li].i==rhs.tri[lj].i)
218             {
219                 if (lhs.tri[li].j==rhs.tri[lj].j)
220                 {
221                     if (lhs.tri[li].ele-rhs.tri[lj].ele!=0)
222                     {
223                         tri[i].i   = lhs.tri[li].i;
224                         tri[i].j   = lhs.tri[li].j;
225                         tri[i].ele = lhs.tri[li].ele-rhs.tri[lj].ele;
226                         ++i;
227                         ++count;
228                     }    
229                     ++li;
230                     ++lj;
231                 }
232                 else
233                 {
234                     if (lhs.tri[li].j<rhs.tri[lj].j)
235                     {
236                         tri[i].ele = lhs.tri[li].ele;
237                         tri[i].i   = lhs.tri[li].i;
238                         tri[i].j   = lhs.tri[li].j;
239                         ++i;
240                         ++count;
241                         ++li;
242                     }
243                     else
244                     {
245                         tri[i].ele = -rhs.tri[lj].ele;
246                         tri[i].i   = rhs.tri[lj].i;
247                         tri[i].j   = rhs.tri[lj].j;
248                         ++i;
249                         ++count;
250                         ++lj;
251                     }
252                 }
253             }
254             else
255             {
256                 if (lhs.tri[li].i<rhs.tri[lj].i)
257                 {
258                     tri[i].ele = lhs.tri[li].ele;
259                     tri[i].i   = lhs.tri[li].i;
260                     tri[i].j   = lhs.tri[li].j;
261                     ++i;
262                     ++count;
263                     ++li;
264                 }
265                 else
266                 {
267                     tri[i].ele = -rhs.tri[lj].ele;
268                     tri[i].i   = rhs.tri[lj].i;
269                     tri[i].j   = rhs.tri[lj].j;
270                     ++i;
271                     ++count;
272                     ++lj;
273                 }
274             }
275 
276         }
277         if (li==lhs.count)
278         {
279             while (lj<rhs.count)
280             {
281                 tri[i].ele = -rhs.tri[lj].ele;
282                 tri[i].i   = rhs.tri[lj].i;
283                 tri[i].j   = rhs.tri[lj].j;
284                 ++i;
285                 ++count;
286                 ++lj;
287             }
288         }
289         if(lj==rhs.count&&lhs.count!=li)
290         {
291             while (li<lhs.count)
292             {
293                 tri[i].ele = lhs.tri[li].ele;
294                 tri[i].i   = lhs.tri[li].i;
295                 tri[i].j   = lhs.tri[li].j;
296                 ++i;
297                 ++count;
298                 ++li;
299             }
300         }
301     }
302     else
303     {
304         std::cout<<"Two matrix is different type and aren't  to  sub!"<<std::endl;
305     }
306 }
307 
308 
309 template<class T>
310 void SpMatrix<T>::mulSpMatrix(const SpMatrix& lhs,const SpMatrix& rhs)
311 {
312     if (lhs.column==rhs.row&&0!=lhs.count||0!=rhs.count)
313     {
314         row    = lhs.row;
315         column = rhs.column;
316         int _alN = row*column;
317         tri      = new Triple<T>[_alN];
318 
319         int *rhsRowSize  = new int[rhs.row];     //save the rhs the element that's not zero int the row
320         int *rhsRowStart = new int[rhs.row];     //save the first element's location int the sparse matrix
321         int *rhsSzele    = new int[rhs.column];
322         for (int i=0;i<rhs.row;++i)
323         {
324             rhsRowSize[i] = 0;
325         }
326         for (int i=0; i<rhs.count;++i) //calculate the numbers of every row in the rhs
327         {
328             ++rhsRowSize[rhs.tri[i].i];
329         }
330 
331         rhsRowStart[0] = 0;            //calculate the location of first row in the rhs 
332         for (int i=1;i<rhs.row;++i)
333         {
334             rhsRowStart[i] = rhsRowStart[i-1]+rhsRowSize[i-1];
335         }
336 
337         int k = 0;
338         for (int i = 0;i<lhs.row;++i)             //calculate the every row of lhs   
339         {
340             for (int j = 0;j<rhs.column;j++)      //make the all 0
341             {
342                 rhsSzele[j] = 0;
343             }
344             while (k<lhs.count && i==lhs.tri[k].i)  
345             {
346                 int cRlF = lhs.tri[k].j;
347                 for (int i = rhsRowStart[cRlF];  cRlF == rhs.tri[i].i ; ++i)
348                 {
349                     rhsSzele[rhs.tri[i].j]+=lhs.tri[k].ele * rhs.tri[i].ele;
350                 }
351                 ++k;
352             }
353 
354             for (int j = 0 ; j<rhs.column;++j)    //deal the result every i row value
355             {
356                 if (0!=rhsSzele[j])
357                 {
358                     tri[count].i   = i;
359                     tri[count].j   = j;
360                     tri[count].ele = rhsSzele[j];
361                     ++count;
362                 }
363             }
364         }
365         delete [] rhsRowSize;
366         delete [] rhsRowStart;
367         delete [] rhsSzele;
368     }
369     else
370     {
371         std::cout<<"两矩阵不满足相成的条件"<<std::endl;
372     }
373 }
374 
375 template<class T>
376 void SpMatrix<T>::sptoprintMatxi()const
377 {
378     if (0!=count)
379     {
380         int k = 0,i = 0;
381         int rcm = row*column;
382         while(i<rcm)
383         {
384 
385             if (i == tri[k].i*column+tri[k].j)
386             {
387                 std::cout<<tri[k].ele<<" ";
388                 ++k;
389             }
390             else
391             {
392                 std::cout<<0<<" ";
393             }
394             if (0 == (i+1)%column )
395             {
396                 std::cout<<::std::endl;
397             }
398             ++i;
399         }
400          std::cout<<std::endl;
401     }
402     std::cout<<std::endl;
403 }
404 
405 template<class T>
406 void SpMatrix<T>::printMatrix()const
407 {
408     for (int i = 0;i<count;++i)
409     {
410         std::cout<<tri[i].i<<" "<<tri[i].j<<" "<<tri[i].ele<<std::endl;
411     }
412     std::cout<<"三元组的个数为:"<<count<<std::endl;
413     std::cout<<std::endl;
414 }
415 
416 
417 #endif //_SpMatrix_H_

KarMenu.h 测试头文件代码

  1 #ifndef _KARMENU_H_
  2 #define _KARMENU_H_
  3 #include "SpMatrix.h"
  4 #include<stdlib.h>
  5 void transMatrix()
  6 {
  7     int **mptr1 = NULL;
  8     int row1,column1;
  9     std::cout<<"Enter the matrix's row ";
 10     std::cin>>row1;
 11     std::cout<<"Enter the matrix's column ";
 12     std::cin>>column1;
 13 
 14     mptr1 = new int *[row1];
 15     for (int i=0;i<row1;++i)
 16     {
 17         mptr1[i]= new int[column1];
 18         for (int j=0;j<column1;++j)
 19         {
 20             std::cin>>mptr1[i][j];
 21         }
 22     }
 23     SpMatrix<int> spmat1(mptr1,row1,column1);
 24     SpMatrix<int> spmat2;
 25     spmat2.transposeSpMatrix(spmat1);
 26     std::cout<<std::endl;
 27     std::cout<<"转置矩阵为:"<<std::endl;
 28     spmat2.sptoprintMatxi();
 29     std::cout<<"转置矩阵的三元组表示为:"<<std::endl;
 30     spmat2.printMatrix();
 31 }
 32 void addMatrixM()
 33 {
 34     int **mptr1 = NULL;
 35     int row1,column1;
 36     std::cout<<"Enter the matrix's row ";
 37     std::cin>>row1;
 38     std::cout<<"Enter the matrix's column ";
 39     std::cin>>column1;
 40 
 41     mptr1 = new int *[row1];
 42     for (int i=0;i<row1;++i)
 43     {
 44         mptr1[i]= new int[column1];
 45         for (int j=0;j<column1;++j)
 46         {
 47             std::cin>>mptr1[i][j];
 48         }
 49     }
 50     int **mptr2 = NULL;
 51     int row2,column2;
 52     std::cout<<"Enter the matrix's row ";
 53     std::cin>>row2;
 54     std::cout<<"Enter the matrix's column ";
 55     std::cin>>column2;
 56 
 57     mptr2 = new int *[row2];
 58     for (int i=0;i<row2;++i)
 59     {
 60         mptr2[i]= new int[column2];
 61         for (int j=0;j<column2;++j)
 62         {
 63             std::cin>>mptr2[i][j];
 64         }
 65     }
 66     if (row1==column1&&row2==column2&&row1==row2)
 67     {
 68         SpMatrix<int> spmat1(mptr1,row1,column1);
 69         SpMatrix<int> spmat2(mptr2,row2,column2);
 70         SpMatrix<int> spmat3;
 71         spmat3.addSpMatrix(spmat1,spmat2);
 72         std::cout<<std::endl;
 73         std::cout<<"被加矩阵的三元组表示:"<<std::endl;
 74         spmat1.printMatrix();
 75         std::cout<<"加数矩阵的三元组表示:"<<std::endl;
 76         spmat2.printMatrix();
 77 
 78         std::cout<<"所求和矩阵的普通表示:"<<std::endl;
 79         spmat3.sptoprintMatxi();
 80         std::cout<<"所求和矩阵的三元组表示:"<<std::endl;
 81         spmat3.printMatrix(); 
 82     }
 83     else
 84     {
 85         std::cout<<"输入矩阵的阶数不符合要求,请输入等阶矩阵操作"<<std::endl;
 86     }
 87 
 88     for (int i = 0;i<row1;++i)
 89     {
 90         delete  [] mptr1[i];
 91     }
 92     delete []mptr1;
 93 
 94     for (int i = 0;i<row2;++i)
 95     {
 96         delete  [] mptr2[i];
 97     }
 98     delete []mptr2;
 99     mptr2 = NULL;
100 }
101 
102 void subMatrixM()
103 {
104     int **mptr1 = NULL;
105     int row1,column1;
106     std::cout<<"Enter the matrix's row ";
107     std::cin>>row1;
108     std::cout<<"Enter the matrix's column ";
109     std::cin>>column1;
110 
111     mptr1 = new int *[row1];
112     for (int i=0;i<row1;++i)
113     {
114         mptr1[i]= new int[column1];
115         for (int j=0;j<column1;++j)
116         {
117             std::cin>>mptr1[i][j];
118         }
119     }
120     int **mptr2 = NULL;
121     int row2,column2;
122     std::cout<<"Enter the matrix's row ";
123     std::cin>>row2;
124     std::cout<<"Enter the matrix's column ";
125     std::cin>>column2;
126 
127     mptr2 = new int *[row2];
128     for (int i=0;i<row2;++i)
129     {
130         mptr2[i]= new int[column2];
131         for (int j=0;j<column2;++j)
132         {
133             std::cin>>mptr2[i][j];
134         }
135     }
136     if(row1==column1&&row2==column2&&row1==row2)
137     {
138         SpMatrix<int> spmat1(mptr1,row1,column1);
139         SpMatrix<int> spmat2(mptr2,row2,column2);
140         SpMatrix<int> spmat3;
141         spmat3.subSpMatrix(spmat1,spmat2);
142         std::cout<<std::endl;
143         std::cout<<"被减矩阵的三元组表示:"<<std::endl;
144         spmat1.printMatrix();
145         std::cout<<"减数矩阵的三元组表示:"<<std::endl;
146         spmat2.printMatrix();
147 
148         std::cout<<"所求差矩阵的普通表示:"<<std::endl;
149         spmat3.sptoprintMatxi();
150         std::cout<<"所求差矩阵的三元组表示:"<<std::endl;
151         spmat3.printMatrix();
152     }
153     else
154     {
155         std::cout<<"您输入的矩阵不符合要求,求输入等阶矩阵"<<std::endl;
156     }
157     for (int i = 0;i<row1;++i)
158     {
159         delete  [] mptr1[i];
160     }
161     delete []mptr1;
162 
163     for (int i = 0;i<row2;++i)
164     {
165         delete  [] mptr2[i];
166     }
167     delete []mptr2;
168 }
169 
170 void mulMatrixM()
171 {
172     int **mptr1 = NULL;
173     int row1,column1;
174     std::cout<<"Enter the matrix's row ";
175     std::cin>>row1;
176     std::cout<<"Enter the matrix's column ";
177     std::cin>>column1;
178 
179     mptr1 = new int *[row1];
180     for (int i=0;i<row1;++i)
181     {
182         mptr1[i]= new int[column1];
183         for (int j=0;j<column1;++j)
184         {
185             std::cin>>mptr1[i][j];
186         }
187     }
188     int **mptr2 = NULL;
189     int row2,column2;
190     std::cout<<"Enter the matrix's row ";
191     std::cin>>row2;
192     std::cout<<"Enter the matrix's column ";
193     std::cin>>column2;
194 
195     mptr2 = new int *[row2];
196     for (int i=0;i<row2;++i)
197     {
198         mptr2[i]= new int[column2];
199         for (int j=0;j<column2;++j)
200         {
201             std::cin>>mptr2[i][j];
202         }
203     }
204     if(column1 == row2)
205     {
206         SpMatrix<int> spmat1(mptr1,row1,column1);
207         SpMatrix<int> spmat2(mptr2,row2,column2);
208         SpMatrix<int> spmat3;
209         spmat3.mulSpMatrix(spmat1,spmat2);
210         std::cout<<std::endl;
211         std::cout<<"被乘矩阵的三元组表示:"<<std::endl;
212         spmat1.printMatrix();
213         std::cout<<"乘数矩阵的三元组表示:"<<std::endl;
214         spmat2.printMatrix();
215 
216         std::cout<<"所求积矩阵的普通表示:"<<std::endl;
217         spmat3.sptoprintMatxi();
218         std::cout<<"所求积矩阵的三元组表示:"<<std::endl;
219         spmat3.printMatrix();
220     }
221     else
222     {
223         std::cout<<"您输入的矩阵不符合要求,请输入符合符合阶数要求的矩阵"<<std::endl;
224 
225     }
226     for (int i = 0;i<row1;++i)
227     {
228         delete  [] mptr1[i];
229     }
230     delete []mptr1;
231 
232     for (int i = 0;i<row2;++i)
233     {
234         delete  [] mptr2[i];
235     }
236     delete []mptr2;
237 
238 }
239 
// Placeholder for a matrix-division test; it currently does nothing.
void divMatrixM()
{
    // intentionally empty
}
244 
245 int menu()
246 {
247     int zmyn;
248     int choose = 0 ;
249     char a[10]={0};
250     do
251     {
252         std::cout<<"测试矩阵的运算"<<std::endl;
253         std::cout<<"1:   矩阵转置"<<std::endl;
254         std::cout<<"2:   矩阵求和"<<std::endl;
255         std::cout<<"3:  矩阵求差"<<std::endl;
256         std::cout<<"4:   矩阵求积"<<std::endl;
257         std::cout<<"5:   退出运算"<<std::endl;
258         std::cout<<"请输入1--5的选项!"<<std::endl;
259         std::cin>>a;
260         choose = atoi(a);
261         while (choose<1||choose>5)
262         {
263             std::cout<<"您输入的选项不和要求,请输入1--5之间的数字进行操作"<<std::endl;
264             std::cin>>a;
265             choose = atoi(a);
266         }
267         switch(choose)
268         {
269         case 1:  
270             transMatrix();break;
271         case 2:   
272             addMatrixM();break;
273         case 3:   
274             subMatrixM();break;
275         case 4:
276             mulMatrixM();break;
277         case 5:   
278             return 0;   
279         default: ;
280         }
281         std::cout<<"输入1继续操作,输入0或者其他停止操作"  <<std::endl;
282         std::cin>>a;
283         zmyn = atoi(a);
284         while (zmyn<0||zmyn>9)
285         {
286             std::cout<<"请正确输入"<<std::endl;
287             std::cin>>a;
288             zmyn = atoi(a);
289         }
290         system("cls");
291     }while(1 == zmyn);
292     
293     return 0;
294 }
295 #endif //end _KARMENU_H_

main函数入口:

1 #include "KarMenu.h"
2 //static const int ROW = 10;
3 //static const int COLUMN = 10;
4 int main(void)
5 {
6     return menu();
7 }

测试:

矩阵转置:

 

矩阵求和:

 

矩阵求差:

 

矩阵求积:

 测试均正确通过

posted @ 2014-05-28 23:19  karllen  阅读(907)  评论(0编辑  收藏  举报