1. 程式人生 > >壓縮稀疏矩陣以及使用三元組實現矩陣乘法,簡單易懂

壓縮稀疏矩陣以及使用三元組實現矩陣乘法,簡單易懂

思路:既然使用三元組去實現,所以首先要定義一個三元組

typedef struct  node {
    int row, col, v;//分別代表行數,列數,以及元素的值,整個式子表示在原矩陣的第row行,第col列,有一個值為v的數
} node;

然後要想實現乘法,矩陣的規模也是要記錄的,所以就有了下面這個結構體

struct T {
    node Node[maxn];//存放三元組
    int MAXrow, MAXcol, MAXsize;//分別表示原矩陣的行數,列數,以及矩陣中非0元素的數量
} TA, TB, TC;//表示矩陣a,b,c(c是a*b的結果)

然後主要變數定義完了,就可以輸入資料啦

//an,am,bn,bm表示矩陣a的行數,列數,矩陣b的行數,列數
scanf("%d%d", &an, &am);//輸入a矩陣
    TA.MAXrow = an;
    TA.MAXcol = am;
    for(int i = 1; i <= an; i++) {//下標從1開始
        for(int j = 1; j <= am; j++) {
            scanf("%d", &a[i][j]);
            if(a[i][j]) {//如果輸入的當前位置不是0,就記錄這個非0元素

                TA.Node
[TA.MAXsize].row = i; TA.Node[TA.MAXsize].col = j; TA.Node[TA.MAXsize].v = a[i][j]; TA.MAXsize++; } } } scanf("%d%d", &bn, &bm);//輸入b矩陣 TB.MAXrow = bn; TB.MAXcol = bm; for(int i = 1; i <= bn; i++) {//下標從1開始
for(int j = 1; j <= bm; j++) { scanf("%d", &b[i][j]); if(b[i][j]) { TB.Node[TB.MAXsize].row = i; TB.Node[TB.MAXsize].col = j; TB.Node[TB.MAXsize].v = b[i][j]; TB.MAXsize++; } } }


最最關鍵的C矩陣就要來了

現在我們來明確一下矩陣乘法是如何定義的

假設現在有一個矩陣a(2*3規模)
a11 a12 a13
a21 a22 a23
還有一個矩陣b(3*2規模)
b11 b12
b21 b22
b31 b32
那麼cij=ai1*b1j+ai2*b2j+……+ain+bnj

由於矩陣是用三元組存的,所以計算的時候就會麻煩一點,每次計算cij都要重新遍歷a,b兩個矩陣。

TC.MAXrow = an;
    TC.MAXrow = bm;
    for(int i = 1; i <= an; i++) {//計算c矩陣
        for(int j = 1; j <= bm; j++) {
            int sum = 0;
            for(int p = 0; p < TA.MAXsize ; p++) {
                if(TA.Node[p].row != i) continue;
                for(int q = 0; q < TB.MAXsize; q++) {
                    if(TB.Node[q].col != j) continue;
                    if(TA.Node[p].col == TB.Node[q].row) {
                        sum += TA.Node[p].v * TB.Node[q].v;
                    }
                }
            }
            if(sum != 0) {
                TC.Node[TC.MAXsize].row = i;
                TC.Node[TC.MAXsize].col = j;
                TC.Node[TC.MAXsize].v = sum;
                TC.MAXsize++;
            }
        }
    }

完整程式碼:

#include <cstdio>
using namespace std;
const int maxn = 1000;
int a[maxn][maxn];
int b[maxn][maxn];
int c[maxn][maxn];
typedef struct  node {
    int row, col, v;
} node;
struct T {
    node Node[maxn];
    int MAXrow, MAXcol, MAXsize;
} TA, TB, TC;
int an, am, bn, bm;
int main() {
    scanf("%d%d", &an, &am);//輸入a矩陣
    TA.MAXrow = an;
    TA.MAXcol = am;
    for(int i = 1; i <= an; i++) {//下標從1開始
        for(int j = 1; j <= am; j++) {
            scanf("%d", &a[i][j]);
            if(a[i][j]) {

                TA.Node[TA.MAXsize].row = i;
                TA.Node[TA.MAXsize].col = j;
                TA.Node[TA.MAXsize].v = a[i][j];
                TA.MAXsize++;
            }
        }
    }
    scanf("%d%d", &bn, &bm);//輸入b矩陣
    TB.MAXrow = bn;
    TB.MAXcol = bm;
    for(int i = 1; i <= bn; i++) {//下標從1開始
        for(int j = 1; j <= bm; j++) {
            scanf("%d", &b[i][j]);
            if(b[i][j]) {

                TB.Node[TB.MAXsize].row = i;
                TB.Node[TB.MAXsize].col = j;
                TB.Node[TB.MAXsize].v = b[i][j];
                TB.MAXsize++;
            }
        }
    }
    TC.MAXrow = an;
    TC.MAXrow = bm;
    for(int i = 1; i <= an; i++) {//計算c矩陣
        for(int j = 1; j <= bm; j++) {
            int sum = 0;
            for(int p = 0; p < TA.MAXsize ; p++) {
                if(TA.Node[p].row != i) continue;
                for(int q = 0; q < TB.MAXsize; q++) {
                    if(TB.Node[q].col != j) continue;
                    if(TA.Node[p].col == TB.Node[q].row) {
                        sum += TA.Node[p].v * TB.Node[q].v;
                    }
                }
            }
            if(sum != 0) {
                TC.Node[TC.MAXsize].row = i;
                TC.Node[TC.MAXsize].col = j;
                TC.Node[TC.MAXsize].v = sum;
                TC.MAXsize++;
            }
        }
    }
    printf("\n\n");
    int t = 0;
    for(int i = 1; i <= an; i++) {//輸出c矩陣
        for(int j = 1; j <= bm; j++) {
            if(TC.Node[t].row == i && TC.Node[t].col == j) {
                printf("%d ", TC.Node[t].v);
                t++;
            } else printf("0 ");

        }
        printf("\n");
    }
    printf("\n");
    return 0;
}