洛谷2151[SDOI2009]HH去散步(dp+矩陣乘法優化)
阿新 • • 發佈:2018-12-23
一道良好的矩陣乘法優化\(dp\)的題。
首先,一個比較\(naive\)的想法。
我們定義\(dp[i][j]\)表示已經走了\(i\)步,當前在點\(j\)的方案數。
由於題目中限制了不能立即走之前走過來的那個點,所以這個狀態並不能優秀的轉移。
嘗試重新定義\(dp\)狀態。
令\(dp[i][j]\)表示已經走了\(i\)步,當前在\(j\)這條邊的終點的那個點。
假設\(to[j]=p\)
那麼\(dp[i][j]\)可以轉移到\(dp[i+1][out[p]] 其中\ (out[p]不為j的反向邊)\)
其中\(out[p]\)表示p的出邊(我們把題目中的每條無向拆成兩個有向邊)
最後求\(ans\)的時候,只需要列舉哪些邊的終點是目標點,然後加起來即可
通過具體的邊的限制,我們就能滿足題目中的那個要求。
qwq但是我們發現,如果暴力轉移的話,時間複雜度是不能夠接受的。
考慮到每次只從\(i\)轉移到\(i+1\)。
所以可以構造一個轉移矩陣。
對於一個狀態\(dp[x][i]\),然後在如果他能對編號為\(j\)的邊產生貢獻,那麼我們把構造矩陣\(a[i][j]\)++
for (int i=1;i<=cnt;i++) { int to = y[i]; for (int j=0;j<out[to].size();j++) { int now = out[to][j]; if((i+1)==((now+1)^1)) continue; b.a[i][now]++; } }
注意不能通過具體的點來判斷,而要判斷是否為反向邊。
然後我們強行令初始矩陣為dp[1][i]的值,就是強行走一步,然後快速冪出來\(k-1\)次方的值,二者相乘,最後求解即可。
// luogu-judger-enable-o2 #include<bits/stdc++.h> #define mk make_pair #define pb push_back #define ll long long #define int long long using namespace std; inline int read() { int x=0,f=1;char ch=getchar(); while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();} while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();} return x*f; } const int maxn = 150; const int maxm = 1e5+1e2; const int mod = 45989; struct Ju{ int x,y; int a[maxn][maxn]; Ju operator * (Ju b) { Ju ans; memset(ans.a,0,sizeof(ans.a)); ans.x=x; ans.y=b.y; for (register int i=1;i<=ans.x;++i) for (register int j=1;j<=ans.y;++j) for (register int k=1;k<=y;++k) ans.a[i][j]=(ans.a[i][j]+a[i][k]*b.a[k][j]%mod)%mod; return ans; } }; Ju qsm(Ju i,int j) { Ju ans; memset(ans.a,0,sizeof(ans.a)); ans.x=i.x; ans.y=i.y; for (int p=1;p<=i.x;p++) ans.a[p][p]=1; while(j) { if (j&1) ans=ans*i; i=i*i; j>>=1; } return ans; }; Ju a,b; int n,m,k,s,t; int x[maxm],y[maxm],w[maxm]; int cnt=0; vector<int> in[maxn],out[maxn]; signed main() { n=read();m=read(),k=read(),s=read(),t=read(); s++; t++; for (int i=1;i<=m;i++) { int u=read(),v=read(); u++; v++; ++cnt; x[cnt]=u,y[cnt]=v; ++cnt; x[cnt]=v,y[cnt]=u; } for (int i=1;i<=cnt;i++) { out[x[i]].pb(i); in[y[i]].pb(i); } for (int i=1;i<=cnt;i++) { int to = y[i]; for (int j=0;j<out[to].size();j++) { int now = out[to][j]; if((i+1)==((now+1)^1)) continue; b.a[i][now]++; } } //for (int i=1;i<=cnt;i++) // { // for (int j=1;j<=cnt;j++) cout<<b.a[i][j]<<" "; // cout<<endl; //} for (int i=0;i<out[s].size();i++) { a.a[1][out[s][i]]++; //cout<<out[s][i]<<" "<<endl; } //cout<<"******************"<<endl; //for (int i=1;i<=cnt;i++) cout<<a.a[1][i]<<" "; //cout<<endl; a.x=1; a.y=cnt; b.x=cnt; b.y=cnt; b=qsm(b,k-1); a=a*b; int ans = 0; for (int i=1;i<=cnt;i++) { if (y[i]==t) ans=(ans+a.a[1][i])%mod; //cout<<ans<<endl; } cout<<ans; return 0; }