1. 程式人生 > >牛客國慶集訓派對Day2 A 矩陣乘法(思維分塊)

牛客國慶集訓派對Day2 A 矩陣乘法(思維分塊)

題目連結

題意:

給你兩個矩陣A,B,

A是n*p,B是p*m,B是一個只有0,1組成的矩陣,Aij<65536

C=A*B,讓你求出C的裡面所有元素的異或和

 

解析:

官方的標解是分塊,每8個分一組。

例如對於A,每行行每8個分成一組,對於B,每一列每8個分成一組,

定義組數為x=p/8+(p%8)1:0

那麼現在A就變成了n*x,B變成x*m

現在我們需要解決的就是當分完塊的A,B相乘時,對應組的乘積

那麼對於這一個我們就可以預處理,因為B的每一列中,每8個一組,那麼每一組的情況只有256種,

我們就可以把A的每一組都對應求在256種情況下,每一種情況的值。這個可以提前打表出來。

那麼預處理的複雜度就是O(n*p*256)

然後最後相乘的複雜度就變成了O(n*m*p/8) ,會達到1e8

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int MAXN = 4096+100;
int n,p,m;

int a[MAXN][80];
int b[80][MAXN];
int d[MAXN][20][257];
int e[20][MAXN];


int main()
{
    scanf("%d%d%d",&n,&p,&m);
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=p;j++)
        {
            scanf("%x",&a[i][j]);
        }
    }

    for(int i=1;i<=m;i++)
    {
        for(int j=1;j<=p;j++)
        {
            scanf("%01d",&b[j][i]);
        }
    }

    int btr=p/8+(p%8?1:0);

    for(int i=1;i<=n;i++)
    {
        for(int w=1;w<=btr;w++)
        {
            for(int j=0;j<256;j++)
            {
                int tmp=j;
                for(int k=1;k<=8;k++)
                {
                    if(tmp&1) d[i][w][j]+=a[i][(w-1)*8+k];
                    tmp>>=1;
                }
            }
        }
    }

    for(int i=1;i<=m;i++)
    {
        for(int j=1;j<=btr;j++)
        {
            for(int k=1;k<=8;k++)
            {
                e[j][i]|=(b[(j-1)*8+k][i]<<(k-1));
            }
        }
    }

    int ans=0;
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=m;j++)
        {
            int res=0;
            for(int k=1;k<=btr;k++)
            {
                res+=d[i][k][e[k][j]];
            }
            ans^=res;
        }
    }
    printf("%d\n",ans);
    return 0;
}

我自己那時候想的方法也過了。。複雜度是O(n*m*p/4)達到2e8,因為最暴力的O(n*m*p)都能過,1e9...

我就是把每一個數按照二進位制拆出來。因為Aij最大隻有2^16,那麼一個A最多隻能被分成16個矩陣

第一矩陣A1ij就表示,Aij的二進位制第1位;第二個矩陣A2ij,表示Aij的二進位制第二位..........

那麼我們就把A1ij*B,然後把各個位置的進位儲存在inc[][]裡面,因為兩個二進位制矩陣相乘複雜度就是O(n*m)

對於進位,我們只需要他們相乘統計1的時候,加上去就可以了,然後在更新當前產生的新的進位

最後我們就需要把超過16位的進位進上去,只需要遍歷一遍inc這個陣列,複雜度O(n*m)

寫的時候,找了一個BUG半天,發現存拆出來的數組裡面的元素最大是有64位的,因為p<=64

所以需要用ull來存。。。。。。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int MAXN = 4096+10;
int n,p,m;
 
ull a[20][MAXN];
int  tmp[80];
ull b[MAXN];
char str[80];
int inc[MAXN][MAXN];
int main()
{
    scanf("%d%d%d",&n,&p,&m);
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=p;j++)
        {
            scanf("%x",&tmp[j]);
        }
        for(int j=1;j<=16;j++)
        {
            for(int k=1;k<=p;k++)
            {
                a[j][i]=a[j][i]<<1;
                if(tmp[k]&1) a[j][i]|=1;
                tmp[k]=tmp[k]>>1;
            }
        }
    }
    for(int i=1;i<=m;i++)
    {
        getchar();
        scanf("%s",str);
        for(int j=0;j<p;j++)
        {
            b[i]=b[i]<<1;
            if(str[j]=='1')
            {
                b[i]|=1;
            }
 
        }
    }
    ull w;
    int num;
    int coun=1;
    int flag;
    int ans=0;
    for(int i=1;i<=16;i++)
    {
        flag=0;
        for(int j=1;j<=n;j++)
        {
            for(int k=1;k<=m;k++)
            {
                w=a[i][j]&b[k];
                num=__builtin_popcountll(w);
                num+=inc[j][k];
                inc[j][k]=num>>1;
                flag^=(num&1);
            }
        }
        ans^=(flag)?coun:0;
        coun=coun<<1;
 
    }
 
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=m;j++)
        {
            /*int ano=coun;
            while(inc[i][j])
            {
                ans^=((inc[i][j]&1)?ano:0);
                inc[i][j]=inc[i][j]>>1;
                ano<<=1;
            }*/
            ans^=(inc[i][j]<<16);
        }
    }
    printf("%d\n",ans);
    return 0;
}