牛客國慶集訓派對Day2 A 矩陣乘法(思維分塊)
阿新 • • 發佈:2018-11-19
題意:
給你兩個矩陣A,B,
A是n*p,B是p*m,B是一個只有0,1組成的矩陣,Aij<65536
C=A*B,讓你求出C的裡面所有元素的異或和
解析:
官方的標解是分塊,每8個分一組。
例如對於A,每行行每8個分成一組,對於B,每一列每8個分成一組,
定義組數為x=p/8+(p%8)1:0
那麼現在A就變成了n*x,B變成x*m
現在我們需要解決的就是當分完塊的A,B相乘時,對應組的乘積
那麼對於這一個我們就可以預處理,因為B的每一列中,每8個一組,那麼每一組的情況只有256種,
我們就可以把A的每一組都對應求在256種情況下,每一種情況的值。這個可以提前打表出來。
那麼預處理的複雜度就是O(n*p*256)
然後最後相乘的複雜度就變成了O(n*m*p/8) ,會達到1e8
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; typedef long long ll; typedef unsigned long long ull; const int MAXN = 4096+100; int n,p,m; int a[MAXN][80]; int b[80][MAXN]; int d[MAXN][20][257]; int e[20][MAXN]; int main() { scanf("%d%d%d",&n,&p,&m); for(int i=1;i<=n;i++) { for(int j=1;j<=p;j++) { scanf("%x",&a[i][j]); } } for(int i=1;i<=m;i++) { for(int j=1;j<=p;j++) { scanf("%01d",&b[j][i]); } } int btr=p/8+(p%8?1:0); for(int i=1;i<=n;i++) { for(int w=1;w<=btr;w++) { for(int j=0;j<256;j++) { int tmp=j; for(int k=1;k<=8;k++) { if(tmp&1) d[i][w][j]+=a[i][(w-1)*8+k]; tmp>>=1; } } } } for(int i=1;i<=m;i++) { for(int j=1;j<=btr;j++) { for(int k=1;k<=8;k++) { e[j][i]|=(b[(j-1)*8+k][i]<<(k-1)); } } } int ans=0; for(int i=1;i<=n;i++) { for(int j=1;j<=m;j++) { int res=0; for(int k=1;k<=btr;k++) { res+=d[i][k][e[k][j]]; } ans^=res; } } printf("%d\n",ans); return 0; }
我自己那時候想的方法也過了。。複雜度是O(n*m*p/4)達到2e8,因為最暴力的O(n*m*p)都能過,1e9...
我就是把每一個數按照二進位制拆出來。因為Aij最大隻有2^16,那麼一個A最多隻能被分成16個矩陣
第一矩陣A1ij就表示,Aij的二進位制第1位;第二個矩陣A2ij,表示Aij的二進位制第二位..........
那麼我們就把A1ij*B,然後把各個位置的進位儲存在inc[][]裡面,因為兩個二進位制矩陣相乘複雜度就是O(n*m)
對於進位,我們只需要他們相乘統計1的時候,加上去就可以了,然後在更新當前產生的新的進位
最後我們就需要把超過16位的進位進上去,只需要遍歷一遍inc這個陣列,複雜度O(n*m)
寫的時候,找了一個BUG半天,發現存拆出來的數組裡面的元素最大是有64位的,因為p<=64
所以需要用ull來存。。。。。。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int MAXN = 4096+10;
int n,p,m;
ull a[20][MAXN];
int tmp[80];
ull b[MAXN];
char str[80];
int inc[MAXN][MAXN];
int main()
{
scanf("%d%d%d",&n,&p,&m);
for(int i=1;i<=n;i++)
{
for(int j=1;j<=p;j++)
{
scanf("%x",&tmp[j]);
}
for(int j=1;j<=16;j++)
{
for(int k=1;k<=p;k++)
{
a[j][i]=a[j][i]<<1;
if(tmp[k]&1) a[j][i]|=1;
tmp[k]=tmp[k]>>1;
}
}
}
for(int i=1;i<=m;i++)
{
getchar();
scanf("%s",str);
for(int j=0;j<p;j++)
{
b[i]=b[i]<<1;
if(str[j]=='1')
{
b[i]|=1;
}
}
}
ull w;
int num;
int coun=1;
int flag;
int ans=0;
for(int i=1;i<=16;i++)
{
flag=0;
for(int j=1;j<=n;j++)
{
for(int k=1;k<=m;k++)
{
w=a[i][j]&b[k];
num=__builtin_popcountll(w);
num+=inc[j][k];
inc[j][k]=num>>1;
flag^=(num&1);
}
}
ans^=(flag)?coun:0;
coun=coun<<1;
}
for(int i=1;i<=n;i++)
{
for(int j=1;j<=m;j++)
{
/*int ano=coun;
while(inc[i][j])
{
ans^=((inc[i][j]&1)?ano:0);
inc[i][j]=inc[i][j]>>1;
ano<<=1;
}*/
ans^=(inc[i][j]<<16);
}
}
printf("%d\n",ans);
return 0;
}