//cross.h
#include <iostream>
using namespace std;
// One stored (non-zero) element of the sparse matrix. A node is linked into
// two singly linked chains at the same time: the chain of its row (through
// `right`) and the chain of its column (through `down`).
struct Node
{
    short row;
    short col;
    int num;
    Node * right; // next element in the same row
    Node * down;  // next element in the same column

    // Default node: position (-1, -1), value 0, not linked anywhere.
    Node() : row(-1), col(-1), num(0), right(NULL), down(NULL) {}

    // Node for element (_row, _col) holding _val, not yet linked into any chain.
    Node(short _row, short _col, int _val)
        : row(_row), col(_col), num(_val), right(NULL), down(NULL) {}
};
class Cross
{
private:
int m_row; //行数
int m_col; //列数
Node ** pRow;
Node ** pCol;
Cross(){}
public:
Cross(int row, int col);
int GetRow() const
{
return m_row;
}
int GetCol() const
{
return m_col;
}
~Cross();
void GetChainFromMatrix(const int * arr);
void Add(int row, int col, int num);
void Multiplication(const Cross & multiplicator, Cross & result);
void GetMatrix(int * arr);
void Show();
};
//cross.cpp
#include "cross.h"
// Creates an empty row x col sparse matrix: allocates one head pointer per
// row and per column, all initially NULL (every chain starts empty).
Cross::Cross(int row, int col):m_row(row), m_col(col), pRow(NULL), pCol(NULL)
{
    // `new T[n]()` value-initialises the pointers to NULL. The previous
    // version reused the first loop's `i` in a second loop, which only
    // compiled under pre-standard (VC6) for-scope rules.
    pRow = new Node*[m_row]();
    try
    {
        pCol = new Node*[m_col]();
    }
    catch(...)
    {
        // The destructor does not run for a partially constructed object,
        // so release the first allocation here before propagating.
        delete [] pRow;
        throw;
    }
}
// Releases every node and both head arrays. Each node is linked into exactly
// one row chain, so walking the rows frees every node once; the column
// chains point at the same nodes and need no separate traversal.
Cross::~Cross()
{
    for(int r = 0; r < m_row; ++r)
    {
        Node * cur = pRow[r];
        while(cur)
        {
            Node * next = cur->right;
            delete cur;
            cur = next;
        }
    }
    delete [] pCol;
    delete [] pRow;
}
//此处arr的维数应和m_row,m_col相同,可通过GetRow,GetCol获得
void Cross::GetChainFromMatrix(const int * arr)
{
for(int i = 0; i < m_row; ++i)
{
for(int j = 0; j < m_col; ++j)
{
if(arr[i * m_row + j] == 0) //如果是零元,则继续
continue;
Node * add = new Node(i, j, arr[i * m_row + j]);
if(pRow[i])
{
Node * temp = pRow[i]; //链接到行上
while(temp)
{
//此处排序,若按从左至右,从上至下的顺序进行,则已有序
/*if(temp->right &&
temp->right->col > add->col)
{
add->right = temp->right;
temp->right = add;
break;
}*/
if(temp->right == NULL)
{
temp->right = add;
break;
}
temp = temp->right;
}
}
else
pRow[i] = add;
if(pCol[j])
{
Node * temp = pCol[j]; //链接到列上
while(temp)
{
//此处排序,若按从左至右,从上至下的顺序进行,则已有序
/*if(temp->down &&
temp->down->row > add->row)
{
add->down = temp->down;
temp->down = add;
break;
}*/
if(temp->down == NULL)
{
temp->down = add;
break;
}
temp = temp->down;
}
}
else
pCol[j] = add;
}
}
}
void Cross::Add(int row, int col, int num)
{
Node * add = new Node(row, col, num);
if(pRow[row])
{
Node * temp = pRow[row]; //链接到行上
while(temp)
{
//此处排序,若按从左至右,从上至下的顺序进行,则已有序
if(temp->right && temp->right->col >
add->col)
{
add->right = temp->right;
temp->right = add;
break;
}
if(temp->right == NULL)
{
temp->right = add;
break;
}
temp = temp->right;
}
}
else
pRow[row] = add;
if(pCol[col])
{
Node * temp = pCol[col]; //链接到列上
while(temp)
{
//此处排序,若按从左至右,从上至下的顺序进行,则已有序
if(temp->down && temp->down->row >
add->row)
{
add->down = temp->down;
temp->down = add;
break;
}
if(temp->down == NULL)
{
temp->down = add;
break;
}
temp = temp->down;
}
}
else
pCol[col] = add;
}
void Cross::Multiplication(const Cross & multiplicator, Cross &
crossResult)//外部确保此处crossResult中没有数据
{
for(int i = 0; i < m_row; ++i)
{
for(int j = 0; j < multiplicator.m_col; ++j)
{
int count = 0; //crossResult中(i,
j)位置的值
Node * pR = pRow[i];
Node * pC = multiplicator.pCol[j];
while(pR && pC)
{
if(pR->col == pC->row)
{
count += (pR->num * pC->num);
pR = pR->right;
pC = pC->down;
}
else if(pR->col < pC->row)
{
pR = pR->right;
}
else
{
pC = pC->down;
}
}// end of while
if(count > 0)
{
crossResult.Add(i, j, count);
}
}
}
}
void Cross::GetMatrix(int * arr)
{
memset(arr, 0, sizeof(int) * m_row * m_col);
for(int i = 0; i < m_row; ++i)
{
Node * temp = pRow[i];
while(temp)
{
arr[i * m_row + temp->col] = temp->num;
temp = temp->right;
}
}
}
void Cross::Show()
{
for(int i = 0; i < m_row; ++i)
{
Node * temp = pRow[i];
while(temp)
{
cout << "(" << temp->row <<
"," << temp->col << "," <<
temp->num << ")" << '\t';
temp = temp->right;
}
cout << endl;
}
}
//mm.cpp
#include "cross.h"
// Demo: converts two 5x5 dense matrices into cross-linked form, prints them,
// multiplies them, and prints the product both as a node list and as a dense
// 5x5 table.
int main()
{
    int num1[5][5] = {
        {1,0,0,0,0},
        {0,2,0,0,0},
        {0,3,0,0,0},
        {0,0,0,4,0},
        {0,0,0,0,5}};
    int num2[5][5] = {
        {1,0,0,0,0},
        {2,0,0,0,0},
        {0,3,0,0,0},
        {0,0,0,0,4},
        {0,0,5,0,0}};
    int num3[5][5];

    // &arr[0][0] is the address of the first element of the contiguous
    // row-major storage, so no C-style cast is needed.
    Cross cr1(5, 5);
    cr1.GetChainFromMatrix(&num1[0][0]);
    cr1.Show();
    cout << endl;

    Cross cr2(5, 5);
    cr2.GetChainFromMatrix(&num2[0][0]);
    cr2.Show();
    cout << endl;

    Cross cr3(5, 5);
    cr1.Multiplication(cr2, cr3);
    cr3.Show();
    cout << endl;

    cr3.GetMatrix(&num3[0][0]);
    for(int i = 0; i < 5; ++i)
    {
        for(int j = 0; j < 5; ++j)
        {
            cout << num3[i][j] << '\t';
        }
        cout << endl;
    }
    return 0;
}