1. 程式人生 > >矩陣乘法優化演算法

矩陣乘法優化演算法

本篇文章大部分思路與程式碼都來自於微信公眾號“CPP開發者”中2016年4月11日的文章《矩陣相乘優化演算法實現講解》,基本相當於這篇文章的重點重述。

矩陣是什麼以及矩陣乘法是怎麼操作的,我想點開這篇文章的人都應該知道了,這裡就不再贅述了。

首先回顧一下我們最樸素的演算法:

//計算矩陣a乘矩陣b,將結果存入c;p是第一個矩陣的行數,q是第二個矩陣的行數,r是第二個矩陣的列數
void mult(int a[MAXN][MAXN],int b[MAXN][MAXN],int c[MAXN][MAXN],int p,int q,int r)
{
    int i,j,k;
    //先對c進行初始化
    for(i=0;i<p;i++)
    {
        for(j=0;j<r;j++)
        {
            c[i][j] = 0;
        }
    }
    //計算矩陣乘法
    for(i=0;i<p;i++)
    {
        for(j=0;j<r;j++)
        {
            for(k=0;k<q;k++)
            {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
}

這個演算法就是直接模擬矩陣乘法的定義,時間複雜度是O(n^3),同時也是Ω(n^3)。

接下來介紹優化演算法:

這個優化演算法的最差時間複雜度也是O(n^3),但是對於矩陣中零比較多的情況會有所改善。

基本思路是遍歷其中一個矩陣的所有元素,計算所有結果中用到這個元素的部分。如果這個元素是零,那麼就沒有必要計算了,略過去。這麼說可能不清楚,所以還是還是那個程式碼吧。

int mult(int a[MAXN][MAXN],int b[MAXN][MAXN],int c[MAXN][MAXN],int p,int q,int r)
{
    int i,j,k;

    for(i=0;i<p;i++)
    {
        for(j=0;j<r;j++)
        {
            c[i][j] = 0;
        }
    }

    for(i=0;i<p;i++)
    {
        for(k=0;k<q;k++)
        {
            if(a[i][k]!=0)    //如果該元素是零,就省去以下計算
            {
                for(j=0;j<r;j++)
                {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
    }
}

比起其他最差時間複雜度有有效降低的演算法,這一優化演算法更便於實現,而且對於零比較多的矩陣會有很好的效果。