1. 程式人生 > >【BZOJ4566】找相同字元【字尾自動機】

【BZOJ4566】找相同字元【字尾自動機】

題意

 給定兩個字串,求兩個字串相同子串的方案數。

分析

 那麼將字串s1建SAM,然後對於s2的每個字首,都在SAM中找出來,並且計數就行。

 我一開始的做法是,建一個u和len,順著s2跑SAM,當st[u].next[c]存在的時候,u=st[u].next[c],len++,這時候找到了這個字首的最長公共字尾,然後順著parent邊向上走,然後res+=cnt[u]*(len-st[st[u].link].len)。為什麼是len-st[st[u].link].len。因為對於狀態u,它的有效長度是[st[st[u].link].len+1,st[u].len]。但是這樣寫完以後TLE了。然後我就去看了下大佬們的做法。思路也是一樣的只是記錄一個f陣列。

  

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 #include <iostream>
  5 
  6 using namespace std;
  7 const int maxn=200000+100;
  8 typedef long long LL;
  9 struct state{
 10     int len,link;
 11     int next[26];
 12
}st[2*maxn]; 13 int cnt[2*maxn],c[2*maxn],ap[2*maxn]; 14 LL f[2*maxn]; 15 char s1[maxn],s2[maxn]; 16 int n1,n2; 17 int last,cur,sz; 18 void init(){ 19 sz=1; 20 last=cur=0; 21 st[0].link=-1; 22 st[0].len=0; 23 } 24 25 void build_sam(int c){ 26 cur=sz++; 27
cnt[cur]=1; 28 st[cur].len=st[last].len+1; 29 int p; 30 for(p=last;p!=-1&&st[p].next[c]==0;p=st[p].link) 31 st[p].next[c]=cur; 32 if(p==-1) 33 st[cur].link=0; 34 else{ 35 int q=st[p].next[c]; 36 if(st[q].len==st[p].len+1) 37 st[cur].link=q; 38 else{ 39 int clone=sz++; 40 st[clone].len=st[p].len+1; 41 st[clone].link=st[q].link; 42 for(int i=0;i<26;i++) 43 st[clone].next[i]=st[q].next[i]; 44 for(;p!=-1&&st[p].next[c]==q;p=st[p].link) 45 st[p].next[c]=clone; 46 st[cur].link=st[q].link=clone; 47 } 48 } 49 last=cur; 50 } 51 int cmp(int a,int b){ 52 return st[a].len>st[b].len; 53 } 54 55 LL update(int u,int len){ 56 LL res=0; 57 while(u){ 58 res+=(LL)(len-st[st[u].link].len)*cnt[u]; 59 u=st[u].link,len=st[u].len; 60 } 61 return res; 62 } 63 64 int main(){ 65 scanf("%s%s",s1,s2); 66 n1=strlen(s1),n2=strlen(s2); 67 init(); 68 for(int i=0;i<n1;i++){ 69 build_sam(s1[i]-'a'); 70 } 71 for(int i=0;i<sz;i++) 72 c[i]=i; 73 sort(c,c+sz,cmp); 74 for(int i=0;i<sz;i++){ 75 int o=c[i]; 76 if(st[o].link!=-1) 77 cnt[st[o].link]+=cnt[o]; 78 } 79 80 LL ans=0; 81 int u=0,len=0; 82 for(int i=0;i<n2;i++){ 83 int c=s2[i]-'a'; 84 while(u!=-1&&st[u].next[c]==0) 85 u=st[u].link,len=st[u].len; 86 if(u==-1) 87 u=0,len=0; 88 else{ 89 u=st[u].next[c],len++; 90 // ans+=update(u,len); 91 ap[u]++,ans+=(LL)cnt[u]*(len-st[st[u].link].len); 92 } 93 } 94 95 for(int i=0;i<sz;i++){ 96 int o=c[i]; 97 if(st[o].link!=-1) 98 f[st[o].link]+=f[o]+ap[o]; 99 } 100 for(int i=1;i<sz;i++){ 101 ans+=(LL)cnt[i]*f[i]*(st[i].len-st[st[i].link].len); 102 } 103 printf("%lld\n",ans); 104 return 0; 105 }
View Code