Introduction to Algorithms: Python Practice (4.2 Strassen's Algorithm for Matrix Multiplication)

4.2.1 Brute-force matrix multiplication

# Brute-force solution
def matrix_multiply(a,b):
    n=len(a)
    c=[[0]*n for i in range(n)]  # quickly build an n x n matrix initialized to zeros
    for i in range (0,n):
        for j in range(0,n):
            c[i][j]=0
            for k in range(0,n):
                c[i][j]+=a[i][k]*b[k][j]
    return c

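Here a and b are assumed to be square matrices of the same order n. The brute-force solution uses three nested for loops, each running n times, so it takes Θ(n^3) time in total.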
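To make the Θ(n^3) count concrete, here is a minimal sketch that repeats the brute-force loops and also counts the scalar multiplications. The function name matrix_multiply_counted and the 2 x 2 test matrices are made up for this illustration and are not part of the original program.

```python
def matrix_multiply_counted(a, b):
    """Brute-force product of two n x n matrices.

    Also returns how many scalar multiplications were performed
    (hypothetical helper, for illustration only).
    """
    n = len(a)
    c = [[0] * n for _ in range(n)]
    count = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                c[i][j] += a[i][k] * b[k][j]
                count += 1  # one scalar multiplication per innermost step
    return c, count


product, count = matrix_multiply_counted([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(product)  # [[19, 22], [43, 50]]
print(count)    # 8, which is 2**3
```

For an n x n input the counter always ends at n**3, which is where the cubic bound comes from.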

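4.2.2 A simple divide-and-conquer algorithm for matrix multiplication

Strategy (precondition: A and B are both n x n square matrices, where n is a power of 2):

(1) Basic idea: to compute C=A*B, partition C, A and B into blocks, multiply the blocks recursively, and then recombine the resulting C11, C12, C21 and C22 into C.

(2) Recursion base case: when the blocks have order 1, C is simply the product of the single elements of A and B.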
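Written out, the block decomposition behind step (1) looks as follows. The recurrence on the last line is the standard one for this scheme (eight half-size products plus Θ(n^2) work for splitting, adding and recombining); it is added here as a supplement and is not derived in the original post.

```latex
\begin{align*}
\begin{pmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{pmatrix}
&=
\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}
\begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix} \\
C_{11} &= A_{11}B_{11} + A_{12}B_{21} \\
C_{12} &= A_{11}B_{12} + A_{12}B_{22} \\
C_{21} &= A_{21}B_{11} + A_{22}B_{21} \\
C_{22} &= A_{21}B_{12} + A_{22}B_{22} \\
T(n) &= 8\,T(n/2) + \Theta(n^2) = \Theta(n^3)
\end{align*}
```

So this version is asymptotically no faster than the brute-force loops.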

def division(a):    # split a matrix into four quadrant blocks
    n=len(a)//2
    a11=[[0 for i in range(n)]for j in range(n)]
    a12=[[0 for i in range(n)]for j in range(n)]
    a21=[[0 for i in range(n)]for j in range(n)]
    a22=[[0 for i in range(n)]for j in range(n)]
    for i in range(n):
        for j in range(n):
            a11[i][j]=a[i][j]
            a12[i][j]=a[i][j+n]
            a21[i][j]=a[i+n][j]
            a22[i][j]=a[i+n][j+n]
    return (a11,a12,a21,a22)

def matrix_combination(a11,a12,a21,a22):
    n2 = len(a11)
    n=n2*2
    a = [[0 for col in range(n)] for row in range(n)]
    for i in range (0,n):
        for j in range (0,n):
            if i <= (n2-1) and j <= (n2-1):
                a[i][j] = a11[i][j]
            elif i <= (n2-1) and j > (n2-1):
                a[i][j] = a12[i][j-n2]
            elif i > (n2-1) and j <= (n2-1):
                a[i][j] = a21[i-n2][j]
            else:
                a[i][j] = a22[i-n2][j-n2]
    return a
def matrix_add(a,b):  # element-wise sum of two matrices
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    for i in range(0,n):
        for j in range(0,n):
            c[i][j] = a[i][j]+b[i][j]
    return c

def matrix_devision_multiply(a,b):   # main routine of the simple divide-and-conquer matrix multiplication
    n=len(a)
    c = [[0 for col in range(n)] for row in range(n)]#c=[[0]*n for i in range(n)]
    if n==1:
        c[0][0]=a[0][0]*b[0][0]
    else:
        (a11,a12,a21,a22)=division(a)   # order must match division's return value
        (b11,b12,b21,b22)=division(b)
        c11=matrix_add(matrix_devision_multiply(a11,b11),matrix_devision_multiply(a12,b21))
        c12=matrix_add(matrix_devision_multiply(a11,b12),matrix_devision_multiply(a12,b22))
        c21=matrix_add(matrix_devision_multiply(a21,b11),matrix_devision_multiply(a22,b21))
        c22=matrix_add(matrix_devision_multiply(a21,b12),matrix_devision_multiply(a22,b22))
        c=matrix_combination(c11,c12,c21,c22)
    return c

a=[[1,1,1,1],[1,1,1,1],[2,2,2,2],[2,2,2,2]]
b=a
print(matrix_devision_multiply(a,b))
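The splitting and recombining steps can also be written with list slicing, which avoids the index arithmetic in the loops. The sketch below is an alternative formulation, not the author's code; the names split_quadrants and join_quadrants are hypothetical, and it assumes a square matrix of even order stored as a list of lists.

```python
def split_quadrants(a):
    """Split an n x n matrix (n even) into four n/2 x n/2 blocks."""
    h = len(a) // 2
    a11 = [row[:h] for row in a[:h]]  # top-left
    a12 = [row[h:] for row in a[:h]]  # top-right
    a21 = [row[:h] for row in a[h:]]  # bottom-left
    a22 = [row[h:] for row in a[h:]]  # bottom-right
    return a11, a12, a21, a22


def join_quadrants(c11, c12, c21, c22):
    """Reassemble four equally sized blocks into one matrix."""
    top = [left + right for left, right in zip(c11, c12)]
    bottom = [left + right for left, right in zip(c21, c22)]
    return top + bottom


m = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
blocks = split_quadrants(m)
print(blocks[1])                     # [[3, 4], [7, 8]]
print(join_quadrants(*blocks) == m)  # True
```

A round trip through both helpers returning the original matrix is a cheap sanity check for this kind of index handling.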

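4.2.3 Strassen's algorithm for matrices

Building on the simple divide-and-conquer idea, Strassen's algorithm prunes the recursion tree further: each recursive call performs only 7 matrix multiplications instead of 8. The price of saving one multiplication is a constant number of extra matrix additions and subtractions.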
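The effect of dropping one recursive product shows up in the running-time recurrence. The following is the standard result for Strassen's scheme, stated as a supplement; the post itself does not derive it.

```latex
\begin{align*}
T(n) &= 7\,T(n/2) + \Theta(n^2) \\
     &= \Theta\!\left(n^{\lg 7}\right) \approx \Theta\!\left(n^{2.81}\right)
\end{align*}
```

The extra additions and subtractions only affect the Θ(n^2) term, so they do not change the exponent.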

def matrix_strassen(a,b):
    n=len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if n==1:
        c[0][0]=a[0][0]*b[0][0]
    else:
        (a11,a12,a21,a22)=division(a)
        (b11,b12,b21,b22)=division(b)
        s1=matrix_add_sub(b12,b22,0)
        s2=matrix_add_sub(a11,a12,1)
        s3=matrix_add_sub(a21,a22,1)
        s4=matrix_add_sub(b21,b11,0)
        s5=matrix_add_sub(a11,a22,1)
        s6=matrix_add_sub(b11,b22,1)
        s7=matrix_add_sub(a12,a22,0)
        s8=matrix_add_sub(b21,b22,1)
        s9=matrix_add_sub(a11,a21,0)
        s10=matrix_add_sub(b11,b12,1)
        p1=matrix_strassen(a11,s1)
        p2=matrix_strassen(s2,b22)
        p3=matrix_strassen(s3,b11)
        p4=matrix_strassen(a22,s4)
        p5=matrix_strassen(s5,s6)
        p6=matrix_strassen(s7,s8)
        p7=matrix_strassen(s9,s10)
        c11=matrix_add_sub(matrix_add_sub(matrix_add_sub(p5,p4,1),p2,0),p6,1)
        c12=matrix_add_sub(p1,p2,1)
        c21=matrix_add_sub(p3,p4,1)
        c22=matrix_add_sub(matrix_add_sub(matrix_add_sub(p5,p1,1),p3,0),p7,0)
        c=matrix_combination(c11,c12,c21,c22)
    return c

# Strassen's algorithm for matrices
def division(a):                              # split a matrix into four quadrant blocks
    n=len(a)//2
    a11=[[0 for i in range(n)]for j in range(n)]
    a12=[[0 for i in range(n)]for j in range(n)]
    a21=[[0 for i in range(n)]for j in range(n)]
    a22=[[0 for i in range(n)]for j in range(n)]
    for i in range(n):
        for j in range(n):
            a11[i][j]=a[i][j]
            a12[i][j]=a[i][j+n]
            a21[i][j]=a[i+n][j]
            a22[i][j]=a[i+n][j+n]
    return (a11,a12,a21,a22)

def matrix_add_sub(a,b,keys):  # keys=1 adds the matrices, keys=0 subtracts b from a
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if keys==1:
        for i in range(n):
            for j in range(n):
                c[i][j] = a[i][j]+b[i][j]
    else:
        for i in range(n):
            for j in range(n):
                c[i][j]=a[i][j]-b[i][j]
    return c
def matrix_combination(a11,a12,a21,a22):    # reassemble four blocks into one matrix
    n2 = len(a11)
    n=n2*2
    a = [[0 for col in range(n)] for row in range(n)]
    for i in range (0,n):
        for j in range (0,n):
            if i <= (n2-1) and j <= (n2-1):
                a[i][j] = a11[i][j]
            elif i <= (n2-1) and j > (n2-1):
                a[i][j] = a12[i][j-n2]
            elif i > (n2-1) and j <= (n2-1):
                a[i][j] = a21[i-n2][j]
            else:
                a[i][j] = a22[i-n2][j-n2]
    return a

a=[[1,1,1,1],[1,1,1,1],[2,2,2,2],[2,2,2,2]]
b=a
print(matrix_strassen(a,b))
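The ten sums and seven products are easy to get wrong by a sign, so it helps to check them on the smallest case. The sketch below applies the same S and P formulas to a 2 x 2 matrix of plain numbers and compares the result with the direct product. The names strassen_2x2 and direct_2x2 and the test values are made up for this illustration.

```python
def strassen_2x2(a, b):
    """Multiply two 2 x 2 matrices of numbers using Strassen's seven products."""
    (a11, a12), (a21, a22) = a
    (b11, b12), (b21, b22) = b
    s1 = b12 - b22
    s2 = a11 + a12
    s3 = a21 + a22
    s4 = b21 - b11
    s5 = a11 + a22
    s6 = b11 + b22
    s7 = a12 - a22
    s8 = b21 + b22
    s9 = a11 - a21
    s10 = b11 + b12
    p1 = a11 * s1
    p2 = s2 * b22
    p3 = s3 * b11
    p4 = a22 * s4
    p5 = s5 * s6
    p6 = s7 * s8
    p7 = s9 * s10
    c11 = p5 + p4 - p2 + p6
    c12 = p1 + p2
    c21 = p3 + p4
    c22 = p5 + p1 - p3 - p7
    return [[c11, c12], [c21, c22]]


def direct_2x2(a, b):
    """Reference result from the definition of the matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(strassen_2x2(a, b))                      # [[19, 22], [43, 50]]
print(strassen_2x2(a, b) == direct_2x2(a, b))  # True
```

The recursive version replaces each scalar here with an n/2 x n/2 block, which works because the formulas never rely on multiplication being commutative.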

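4.2.4 Modifying Strassen's algorithm to handle matrix sizes n that are not powers of 2

The idea is to expand a matrix whose order is not a power of 2 into one whose order is a power of 2 by filling the extra rows and columns with 0. After the result c has been reassembled, the extra rows and columns of zeros are discarded. Two functions, matrix_expand and matrix_shrink, are therefore added to the Strassen program above. The main program first checks the order of the input matrices A and B: if it is a power of 2, nothing extra is needed and the plain Strassen algorithm is used directly; otherwise A and B are expanded first, and the computed result is shrunk afterwards.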
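Why zero padding gives the right answer can be seen from the block form of the product. Here p denotes the padded order (the smallest power of 2 that is at least n) and each 0 stands for a zero block of the appropriate shape; the symbol p is introduced only for this note.

```latex
\begin{align*}
p &= 2^{\lceil \lg n \rceil} \\
\begin{pmatrix} A & 0 \\ 0 & 0 \end{pmatrix}
\begin{pmatrix} B & 0 \\ 0 & 0 \end{pmatrix}
&=
\begin{pmatrix} AB & 0 \\ 0 & 0 \end{pmatrix}
\end{align*}
```

The top-left n x n block of the padded product is exactly AB, so cropping to that block recovers the result.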

# -*- coding: utf-8 -*-
# Strassen's algorithm for matrices
from math import *
def matrix_strassen(a,b):
    n=len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if n==1:
        c[0][0]=a[0][0]*b[0][0]
    else:
        (a11,a12,a21,a22)=division(a)
        (b11,b12,b21,b22)=division(b)
        s1=matrix_add_sub(b12,b22,0)
        s2=matrix_add_sub(a11,a12,1)
        s3=matrix_add_sub(a21,a22,1)
        s4=matrix_add_sub(b21,b11,0)
        s5=matrix_add_sub(a11,a22,1)
        s6=matrix_add_sub(b11,b22,1)
        s7=matrix_add_sub(a12,a22,0)
        s8=matrix_add_sub(b21,b22,1)
        s9=matrix_add_sub(a11,a21,0)
        s10=matrix_add_sub(b11,b12,1)
        p1=matrix_strassen(a11,s1)
        p2=matrix_strassen(s2,b22)
        p3=matrix_strassen(s3,b11)
        p4=matrix_strassen(a22,s4)
        p5=matrix_strassen(s5,s6)
        p6=matrix_strassen(s7,s8)
        p7=matrix_strassen(s9,s10)
        c11=matrix_add_sub(matrix_add_sub(matrix_add_sub(p5,p4,1),p2,0),p6,1)
        c12=matrix_add_sub(p1,p2,1)
        c21=matrix_add_sub(p3,p4,1)
        c22=matrix_add_sub(matrix_add_sub(matrix_add_sub(p5,p1,1),p3,0),p7,0)
        c=matrix_combination(c11,c12,c21,c22)
    return c

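def matrix_expand(a):       # pad a with zero rows and columns up to the next power of 2
    n=len(a)
    p=1
    while p<n:              # integer loop avoids floating-point error in ceil(log(n,2))
        p*=2
    c=[[0 for col in range(p)]for row in range(p)]   # p x p matrix of zeros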
    for i in range(p):
        for j in range(p):
            if i>=n or j>=n:
                c[i][j]=0
            else:
                c[i][j]=a[i][j]
    return c
def matrix_shrink(a,b):     # keep only the top-left len(b) x len(b) block of a
    n=len(b)
    c=[[0 for col in range(n)]for row in range(n)]
    for i in range(n):
        for j in range(n):
            c[i][j]=a[i][j]
    return c

def division(a):                              # split a matrix into four quadrant blocks
    n=len(a)//2
    a11=[[0 for i in range(n)]for j in range(n)]
    a12=[[0 for i in range(n)]for j in range(n)]
    a21=[[0 for i in range(n)]for j in range(n)]
    a22=[[0 for i in range(n)]for j in range(n)]
    for i in range(n):
        for j in range(n):
            a11[i][j]=a[i][j]
            a12[i][j]=a[i][j+n]
            a21[i][j]=a[i+n][j]
            a22[i][j]=a[i+n][j+n]
    return (a11,a12,a21,a22)

def matrix_add_sub(a,b,keys):  # keys=1 adds the matrices, keys=0 subtracts b from a
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if keys==1:
        for i in range(n):
            for j in range(n):
                c[i][j] = a[i][j]+b[i][j]
    else:
        for i in range(n):
            for j in range(n):
                c[i][j]=a[i][j]-b[i][j]
    return c
def matrix_combination(a11,a12,a21,a22):    # reassemble four blocks into one matrix
    n2 = len(a11)
    n=n2*2
    a = [[0 for col in range(n)] for row in range(n)]
    for i in range (0,n):
        for j in range (0,n):
            if i <= (n2-1) and j <= (n2-1):
                a[i][j] = a11[i][j]
            elif i <= (n2-1) and j > (n2-1):
                a[i][j] = a12[i][j-n2]
            elif i > (n2-1) and j <= (n2-1):
                a[i][j] = a21[i-n2][j]
            else:
                a[i][j] = a22[i-n2][j-n2]
    return a

a=[[1,1,1,1,1],[1,1,1,1,1],[2,2,2,2,2],[2,2,2,2,2],[3,3,3,3,3]]
b=a
n=len(a)
if (n&(n-1))==0:  # n is a power of 2 (exact integer test, no floating-point log)
    print(matrix_strassen(a,b))
else:
    print(matrix_shrink(matrix_strassen(matrix_expand(a),matrix_expand(b)),a))

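This is the simplest approach I could come up with, but it increases the amount of computation: a matrix of order 5, for example, is converted to one of order 8, which adds many unnecessary multiplications of 0 by 0. I have not yet looked up related material online and my ability is limited, so please bear with me.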
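The overhead for the order 5 example can be put in numbers. The sketch below is a rough count under the assumption that the recursion goes all the way down to 1 x 1 blocks, as in the program above; the helper names next_power_of_two and strassen_scalar_mults are hypothetical.

```python
def next_power_of_two(n):
    """Smallest power of 2 that is >= n (for n >= 1)."""
    p = 1
    while p < n:
        p *= 2
    return p


def strassen_scalar_mults(p):
    """Scalar multiplications Strassen performs on a p x p input.

    Assumes p is a power of 2 and the recursion stops at 1 x 1 blocks.
    """
    return 1 if p == 1 else 7 * strassen_scalar_mults(p // 2)


n = 5
p = next_power_of_two(n)
print(p)                         # 8
print(p * p - n * n)             # 39 padded zero entries out of 64
print(strassen_scalar_mults(p))  # 343
print(n ** 3)                    # 125 for the brute-force loops
```

Under these assumptions the padded Strassen run performs 343 scalar multiplications where the brute-force loops need 125, which matches the observation that padding adds work for small matrices.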

References:

1. Introduction to Algorithms, China Machine Press, Chapter 4, Section 2: Strassen's algorithm for matrix multiplication.