稀疏矩阵运算

稀疏矩阵运算

在实现过程中，避免使用 C++ 的高级容器（如 STL 的 vector、map 等），改用手动管理内存的三元组（行、列、值）数组来存储非零元。

#include <iostream>
#include <cstdlib>

using namespace std;

// One stored element of a sparse matrix: a (row, col, value) triple.
class Node{
public:
    int row;
    int col;
    int value;
    // Zero-initialize so default-constructed arrays (new Node[n]) never hold garbage.
    Node() : row(0), col(0), value(0) {}
    Node(int x, int y, int v) : row(x), col(y), value(v) {}
    ~Node() {}
    // Overwrite all three fields at once.
    void set(int x, int y, int v){
        row = x;
        col = y;
        value = v;
    }
};

class SparseMatrix{
private:
    int row;
    int col;
    int len;
    Node* list;
    bool comp(const Node&a,const Node&b){
        if(a.row != b.row){
            return a.row < b.row;
        }else return a.col <= b.col;
    }
    void Merge(int head,int mid,int tail,Node* arr){
        Node* temp = new Node[tail-head+1];
        int cnt = 0;
        int l = head,r = mid+1;
        while(l <= mid && r <= tail){
            if(comp(arr[l], arr[r])){
                temp[cnt++] = arr[l++];
            }else{
                temp[cnt++] = arr[r++];
            }
        }
        while(l <= mid){
            temp[cnt++] = arr[l++];
        }
        while(r <= tail){
            temp[cnt++] = arr[r++];
        }
        for(int i = 0; i < cnt; i++){
            arr[i+head] = temp[i];
        }
        delete[] temp;
    }

    void MergeSort(int head,int tail,Node* arr){
        if(head < tail){
            int mid = (head + tail) / 2;
            MergeSort(head, mid, arr);
            MergeSort(mid+1, tail, arr);
            Merge(head, mid, tail, arr);
        }
    }
public:
    SparseMatrix(){
        row = col = len = 0;
        list = NULL;
    };
    ~SparseMatrix(){
        row = col = len = 0;
        delete[] list;
        list = NULL;
    }
    void CreateSMatrix();
    bool PrintSMatrix();
    bool CopySMatrix(SparseMatrix& ans);
    bool AddSMatrix(const SparseMatrix& anotherSM,SparseMatrix& ans);
    bool SubSMatrix(const SparseMatrix& anotherSM,SparseMatrix& ans);
    bool MultSMatrix(const SparseMatrix& anotherSM,SparseMatrix& ans);
    bool TransposeSMatrix(SparseMatrix& ans);
};
// Read a matrix from stdin: dimensions, the count of non-zero elements, then one
// "row col value" triple per element. The triples are sorted row-major afterwards.
// Complexity: O(n log n), n = number of non-zero elements (dominated by the sort).
void SparseMatrix::CreateSMatrix(){
    int x = 0, y = 0, v = 0; // initialized so a failed read never stores garbage
    cout << "Please input row,col" << endl;
    cin >> row >> col;
    cout << "Please input the number of non-zero elements" << endl;
    cin >> len;
    if(len < 0){
        len = 0; // a negative count would make new[] throw
    }
    cout << "Please input the info of each elements" << endl;
    delete[] list; // release storage from a previous call instead of leaking it
    list = new Node[len];
    for(int i = 0; i < len; i++){
        cin >> x >> y >> v;
        list[i].set(x, y, v);
    }
    MergeSort(0,len-1,list);
}

// Print the matrix densely: values in a row separated by ' ', rows separated by
// '\n', with no trailing newline. Returns false if the matrix was never initialized.
// Complexity: O(row*col + len).
bool SparseMatrix::PrintSMatrix(){
    if(list == NULL){
        return false;
    }
    cout << endl << "SMatrix: " << endl;

    int ptr = 0; // next stored triple to match (triples are sorted row-major)
    for(int x = 1; x <= row; x++){
        for(int y = 1; y <= col; y++){
            // Skip triples that sort before the current cell: duplicates of an
            // already printed position or out-of-range indices. Without this the
            // cursor could pass them and never terminate.
            while(ptr < len &&
                  (list[ptr].row < x || (list[ptr].row == x && list[ptr].col < y))){
                ptr++;
            }
            if(ptr < len && list[ptr].row == x && list[ptr].col == y){
                cout << list[ptr].value;
                ptr++;
            }else{
                cout << 0;
            }
            if(y < col){
                cout << ' ';
            }
        }
        if(x < row){
            cout << '\n';
        }
    }
    return true;
}

// Deep-copy this matrix into `ans`. Returns false if this matrix is uninitialized.
// Complexity: O(n), n = number of stored elements.
bool SparseMatrix::CopySMatrix(SparseMatrix& ans){
    if(list == NULL){
        return false;
    }
    if(&ans == this){
        return true; // copying onto itself: freeing ans.list first would destroy the source
    }
    // Allocate and fill before releasing ans's old storage.
    Node* copy = new Node[len];
    for(int i = 0; i < len; i++){
        copy[i] = list[i];
    }
    delete[] ans.list;
    ans.list = copy;
    ans.row = row;
    ans.col = col;
    ans.len = len;
    return true;
}

bool SparseMatrix::AddSMatrix(const SparseMatrix& anotherSM,SparseMatrix& ans){ // 复杂度O(n+m), n+m为两个矩阵的非零元个数
    if(list == NULL || anotherSM.list == NULL){
        return false;
    }else if(row != anotherSM.row || col != anotherSM.col){
        return false;
    }else{
        if(ans.list != NULL){
            delete[] ans.list;
        }
        ans.row = row;
        ans.col = col;
        int cnt = 0;
        int l = 0;
        int r = 0;
        Node* temp = new Node[len+anotherSM.len];
        while(l < len && r < anotherSM.len){
            if(comp(list[l],anotherSM.list[r])){
                temp[cnt++] = list[l++];
            }else if(list[l].row == anotherSM.list[r].col && list[l].col == anotherSM.list[r].col){
                temp[cnt].row = list[l].row;
                temp[cnt].col = list[l].col;
                temp[cnt].value = list[l].value + anotherSM.list[r].value;
                cnt++;
                l++;
                r++;
            }else{
                temp[cnt++] = anotherSM.list[r++];
            }
        }

        while(l < len){
            temp[cnt++] = list[l++];
        }
        while(r < anotherSM.len){
            temp[cnt++] = anotherSM.list[r++];
        }
        ans.len = cnt;
        ans.list = new Node[cnt];
        for(int i = 0; i < cnt; i++){
            ans.list[i].row = temp[i].row;
            ans.list[i].col = temp[i].col;
            ans.list[i].value = temp[i].value;
        }
        delete[] temp;
        return true;
    }
}


bool SparseMatrix::SubSMatrix(const SparseMatrix& anotherSM,SparseMatrix& ans){ // 复杂度O(n+m), n+m为两个矩阵的非零元个数
    if(list == NULL || anotherSM.list == NULL){
        return false;
    }else if(row != anotherSM.row || col != anotherSM.col){
        return false;
    }else{
        if(ans.list != NULL){
            delete[] ans.list;
        }
        ans.row = row;
        ans.col = col;
        int cnt = 0;
        int l = 0;
        int r = 0;
        Node* temp = new Node[len+anotherSM.len];
        while(l < len && r < anotherSM.len){
            if(comp(list[l],anotherSM.list[r])){
                temp[cnt++] = list[l++];
            }else if(list[l].row == anotherSM.list[r].col && list[l].col == anotherSM.list[r].col){
                temp[cnt].row = list[l].row;
                temp[cnt].col = list[l].col;
                temp[cnt].value = list[l].value - anotherSM.list[r].value;
                cnt++;
                l++;
                r++;
            }else{
                temp[cnt].row = anotherSM.list[r].row;
                temp[cnt].col = anotherSM.list[r].col;
                temp[cnt].value = -anotherSM.list[r].value;
                cnt++;
                r++;
            }
        }

        while(l < len){
            temp[cnt++] = list[l++];
        }
        while(r < anotherSM.len){
            temp[cnt].row = anotherSM.list[r].row;
            temp[cnt].col = anotherSM.list[r].col;
            temp[cnt].value = -anotherSM.list[r].value;
            cnt++;
            r++;
        }
        ans.len = cnt;
        ans.list = new Node[cnt];
        for(int i = 0; i < cnt; i++){
            ans.list[i].row = temp[i].row;
            ans.list[i].col = temp[i].col;
            ans.list[i].value = temp[i].value;
        }
        delete[] temp;
        return true;
    }
}

bool SparseMatrix::MultSMatrix(const SparseMatrix& anotherSM,SparseMatrix& ans){ // 复杂度O(n*m), n,m为两个矩阵的非零元个数
    // m*s s*n -> m*n
    if(list == NULL || anotherSM.list == NULL){
        return false;
    }else if(col != anotherSM.row){
        return false;
    }else{
        if(ans.list != NULL){
            delete[] ans.list;
        }

        ans.row = row;
        ans.col = anotherSM.col;
        Node* temp = new Node[100];
        int cnt = 0;
        // O(n*m)
        for(int i = 0; i < len; i++){
            for(int j = 0; j < anotherSM.len; j++){
                if(list[i].col == anotherSM.list[j].row){
                    temp[cnt].row = list[i].row;
                    temp[cnt].col = anotherSM.list[j].col;
                    temp[cnt].value = list[i].value * anotherSM.list[j].value;
                    cnt++;
                }
            }
        }
        // O((n+m)log(n+m))
        MergeSort(0,cnt-1,temp);
        int k = 1;
        
        // O(n+m)
        for(int i = 1; i < cnt; i++){
            if(temp[i].row != temp[i-1].row || temp[i].col != temp[i-1].col){
                k++;
            }
        }
        ans.len = k;
        ans.list = new Node[k];
        ans.list[0] = temp[0];
        int ptr = 0;
        
        // O(n+m)
        for(int i = 1; i < cnt; i++){
            if(temp[i].row != temp[i-1].row || temp[i].col != temp[i-1].col){
                ptr++;
                ans.list[ptr].row = temp[i].row;
                ans.list[ptr].col = temp[i].col;
                ans.list[ptr].value = temp[i].value;
            }else{
                ans.list[ptr].value += temp[i].value;
            }
        }
        
        delete[] temp;
        return false;
    }
}

// ans = transpose of *this: swaps each element's row/col and the dimensions, then
// re-sorts row-major. Returns false if this matrix is uninitialized.
// Complexity: O(n log n), n = number of stored elements (dominated by the sort).
bool SparseMatrix::TransposeSMatrix(SparseMatrix& ans){
    if(list == NULL){
        return false;
    }
    // Capture everything before touching ans, which may be *this.
    const int resultRow = col;
    const int resultCol = row;
    const int resultLen = len;
    Node* result = new Node[resultLen];
    for(int i = 0; i < resultLen; i++){
        result[i].set(list[i].col, list[i].row, list[i].value);
    }
    MergeSort(0, resultLen-1, result);

    delete[] ans.list;
    ans.list = result;
    ans.row = resultRow;
    ans.col = resultCol;
    ans.len = resultLen;
    return true;
}
int main(){
    SparseMatrix A;
    A.CreateSMatrix();
    A.PrintSMatrix();

    SparseMatrix B;
    A.TransposeSMatrix(B);
    B.PrintSMatrix();

    SparseMatrix C;
    A.MultSMatrix(B,C);
    
    C.PrintSMatrix();
    return 0;
}





posted @ 2020-10-06 09:39  popozyl  阅读(176)  评论(0)  编辑  收藏  举报