1. 程式人生 > >BZOJ 1009 [HNOI2008]GT考試(矩陣快速冪優化DP+KMP)

BZOJ 1009 [HNOI2008]GT考試(矩陣快速冪優化DP+KMP)

題意:

求長度為n的不含長為m的指定子串的字串的個數

1s, n<=1e9, m<=50

思路:

長見識了。。

設那個指定子串為s

f[i][j]表示長度為i的字串(其中後j個字元與s的前j個字元一致的情況下)的方法數

若匹配到s串長度為i的字尾加一個字元num可以組成最長長度為j的字尾,設a[i][j]為num的方法數

例如,s為12312,a為

9 1 0 0 0 0
8 1 1 0 0 0
8 1 0 1 0 0
9 0 0 0 1 0
8 1 0 0 0 1

(i,j都是從0到m-1)

如a[1][2]表示從“1”到“12”可以加的字元方法數,顯然加“2”才可以,所以a[1][2]=1

而a[2][0]表示從“12”到“”可以加的字元方法數:顯然不能加“3”,不然s串會匹配到"123";也不能加“1”,不然s串會匹配成"1"。所以a[2][0]=8

求a矩陣的方法是kmp,感覺只可意會(我寫不出來QAQ)

 

顯然f[i][x]只能由f[i-1][k]轉移而來,而k為多少,要看a陣列了

然後狀態轉移方程為:$f[i][j] = f[i-1][0]*a[0][j]+f[i-1][1]*a[1][j] +\dots + f[i-1][m-1]*a[m-1][j]$

這個狀態轉移方程可以用矩陣快速冪來加速

答案就是$\sum f[n][i]$

程式碼:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include
<cmath> #include<cstring> #include<string> #include<stack> #include<queue> #include<deque> #include<set> #include<vector> #include<map> #include<functional> #define fst first #define sc second #define pb push_back #define
mem(a,b) memset(a,b,sizeof(a)) #define lson l,mid,root<<1 #define rson mid+1,r,root<<1|1 #define lc root<<1 #define rc root<<1|1 #define lowbit(x) ((x)&(-x)) using namespace std; typedef double db; typedef long double ldb; typedef long long ll; typedef unsigned long long ull; typedef pair<int,int> PI; typedef pair<ll,ll> PLL; const db eps = 1e-6; //const int mod = 1e9+7; const int maxn = 2e3+100; const int maxm = 2e6+100; const int inf = 0x3f3f3f3f; const db pi = acos(-1.0); int a[60][60]; int f[60][60]; int n, m, mod; char s[maxn]; int Next[maxn]; void mtpl(int a[60][60], int b[60][60], int s[60][60]){ int tmp[60][60]; for(int i = 0; i < m; i++){ for(int j = 0; j < m; j++){ tmp[i][j] = 0; for(int k = 0; k < m; k++){ tmp[i][j]+=a[i][k]*b[k][j]%mod; tmp[i][j]%=mod; } } } for(int i = 0; i < m; i++){ for(int j = 0; j < m; j++){ s[i][j] = tmp[i][j]; } } return; } void fp(int x){ while(x){ if(x&1)mtpl(f,a,f); mtpl(a,a,a); x>>=1; } return; } void kmp(){ int fix = 0; for(int i = 2; i <= m; i++){ while(fix && s[fix+1]!=s[i])fix=Next[fix]; if(s[fix+1]==s[i])++fix; Next[i]=fix; } for(int i = 0; i < m; i++){ for(char j = '0'; j <= '9'; j++){ fix = i; while(fix&&s[fix+1]!=j)fix=Next[fix]; if(j==s[fix+1])a[i][fix+1]++; else a[i][0]++; } } return; } int main(){ scanf("%d %d %d", &n, &m, &mod); scanf("%s", s+1); mem(a, 0); kmp(); mem(f,0); f[0][0]=1; fp(n); int ans = 0; for(int i = 0; i < m; i++){ ans += f[0][i]; ans%=mod; } printf("%d", ans); return 0; } /* 5 3 4 5 1 2 */