1. 程式人生 > >【HNOI2016】—找相同字元(字尾自動機)

【HNOI2016】—找相同字元(字尾自動機)

傳送門

一道不錯的題

雖然不知道為什麼一群 D a l a o Dalao 要用廣義 S

a m Sam

S a m Sam 隨便搞啊

首先我們可以先對一個串建 S

a m Sam

用另一個串在上面跑

我們考慮對於第二個串的每一個結尾對答案的貢獻數

顯然是 P a r e n

t T r e e Parent-Tree 上當前 S a m Sam 所在的點到的根的所有子串的大小
當然這只是大概的一個說法,實際上不準確,應該是到根的所有點集合內的子串的個數(包括不同地方的不同串每個都算一個)之和,感性理解到意思就可以了

現在考慮如何維護這樣一個東西

顯然我們可以先計算出每個 e n d p o s endpos 集合 e n d p o s endpos 大小

然後乘以這個集合內有的串的個數,也就是 l e n [ u ] l e n [ l i n k [ u ] ] len[u]-len[link[u]]

因為我們也不可能一直向根跑來統計答案

所以我們利用字首和的思想,把計算出的值往下傳

這樣最後每個點記錄的就都是其到根的值了

但是注意第二個串匹配的時候不能直接加上當前這個點的值

應該是他 l i n k link 的值加上在這個集合的子串大小乘上 e n d p o s endpos 集合

所以記錄一下當前最長匹配的串長就可以了

具體看程式碼更好理解一些

#include<bits/stdc++.h>
using namespace std;
#define ll long long 
inline int read(){
    char ch=getchar();
    int res=0;
    while(!isdigit(ch))ch=getchar();
    while(isdigit(ch))res=(res<<3)+(res<<1)+(ch^48),ch=getchar();
    return res;
}
const int N=400005;
int nxt[N][27],len[N],link[N],siz[N],A[N],p[N],ksiz[N],tot,last;
char a[N];
ll ans;
inline void sa_extend(int c){
    int cur=++tot,p=last;last=tot;
    len[cur]=len[p]+1,siz[cur]=1;
    for(;p&&!nxt[p][c];p=link[p])nxt[p][c]=cur;
    if(!p)link[cur]=1;
    else{
        int q=nxt[p][c];
        if(len[q]==len[p]+1)link[cur]=q;
        else{
            int clo=++tot;
            memcpy(nxt[clo],nxt[q],sizeof(nxt[q]));
            link[clo]=link[q],len[clo]=len[p]+1;
            for(;p&&nxt[p][c]==q;p=link[p])nxt[p][c]=clo;
            link[cur]=link[q]=clo;
        }
    }
}
int main(){
    scanf("%s",a+1);last=tot=1;
    int lent=strlen(a+1);
    for(int i=1;i<=lent;++i)sa_extend(a[i]-'a');
    for(int i=1;i<=tot;i++)A[len[i]]++;
    for(int i=1;i<=tot;i++)A[i]+=A[i-1];
    for(int i=1;i<=tot;i++)p[A[len[i]]--]=i;
    for(int i=tot;i;--i)siz[link[p[i]]]+=siz[p[i]];
    for(int i=1;i<=tot;i++)
        ksiz[p[i]]=ksiz[link[p[i]]]+(len[p[i]]-len[link[p[i]]])*siz[p[i]];
    scanf("%s",a+1);
    lent=strlen(a+1);
    int p=1,let=0;
    for(int i=1;i<=lent;i++){
        int c=a[i]-'a';
        for(;p&&!nxt[p][c];p=link[p]);
        if(!p)let=0,p=1;
        else{
            let=min(len[p],let)+1,p=nxt[p][c];
            ans+=ksiz[link[p]]+(let-len[link[p]])*siz[p];
        }
    }
    cout<<ans;
}