1. 程式人生 > >【XSY3320】string AC自動機 雜湊 點分治

【XSY3320】string AC自動機 雜湊 點分治

題目大意

  給一棵樹,每條邊上有一個字元,求有多少對 \((x,y)(x<y)\),滿足 \(x\)\(y\) 路徑上的邊上的字元按順序組成的字串為迴文串。

  \(1\leq n\leq 50000,1\leq x_i,y_i\leq n,z_i\in\{0,1\}\)

題解

  觀察一條經過重心的迴文串是長什麼樣的

  \(S\) 是一個任意的字串,\(T\) 是一個迴文串。

  建出根到每個節點對應的串的AC自動機。

  那麼 \(x\) 這邊的 \(S\) 串就是 \(x\) 對應的AC自動機節點的一個字尾, \(T\) 串是一個字首。

  dfs 整棵樹的 fail 樹,先統計每個點作為 \(x\)

點的貢獻,再把作為 \(y\) 點的貢獻加到資料結構中。

  開 \(\sqrt n\) 個長度為 \(\sqrt n\) 的陣列 \(c_{1,\sqrt n}\)\(c_{i,j}\) 表示當前節點有多少個長度 \(\bmod i=j\) 的祖先。

  當一個點是 \(y\) 點的時候,令對應長度的字串的出現次數 \(+1\),還要對於 \(\leq \sqrt n\) 的所有數 \(i\),令 \(c_{i,\lvert S \rvert \bmod i}++\)

  當一個點是 \(x\) 點的時候,一個迴文串的所有迴文字首可以被表示為 \(O(\log n)\) 個等差數列,公差 \(\leq \sqrt n\)

的那部分在 \(c\) 裡面查,剩下的暴力查就好了。

  記一個等差數列的首項為 \(a_1\),公差為 \(d\),末項為 \(a_n\),那麼貢獻就是 dfs 到深度為 \(a_n\) 的點時 \(c_{d,a_1\bmod d}\) 的值減掉 dfs 到深度為 \(a_1-d\) 的點時 \(c_{d,a_1\bmod d}\) 的值。

  先 dfs 一遍把所有詢問的資訊插到 vector 中,再 dfs 一遍計算答案。

  求一個串的所有迴文字首可以直接雜湊。

  時間複雜度:\(f(n)=O(n^\frac{3}{2})+O(n\log^2 n)=O(n^\frac{3}{2})\)

  \(T(n)=2T(\frac{n}{2})+f(n)=2T(\frac{n}{2})+O(n^\frac{3}{2})=O(n^\frac{3}{2})\)

程式碼

  把這份程式碼中的字尾自動機換成 AC自動機,迴文自動機換成雜湊就好了。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<functional>
#include<cmath>
#include<vector>
#include<queue>
#include<assert.h>
//using namespace std;
using std::min;
using std::max;
using std::swap;
using std::sort;
using std::reverse;
using std::random_shuffle;
using std::lower_bound;
using std::upper_bound;
using std::unique;
using std::vector;
using std::queue;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef std::pair<int,int> pii;
typedef std::pair<ll,ll> pll;
void open(const char *s){
#ifndef ONLINE_JUDGE
    char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
void open2(const char *s){
#ifdef DEBUG
    char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;}
void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');}
int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;}
int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;}
const int N=50010;
vector<pii> g[N];
int sz[N];
int totsz,rt,rtsz;
int b[N];
int n;
int f[N];
ll* ss[N];
ll ss2[N];
ll ans=0;
int _log[N];
struct info
{
    int x;
    int y;
    int z;
    info(int a=0,int b=0,int c=0):x(a),y(b),z(c){}
};
int cmp(info a,info b)
{
    if(a.x!=b.x)
        return a.x<b.x;
    return a.z<b.z;
}
void dfs1(int x,int fa)
{
    sz[x]=1;
    for(auto v:g[x])
        if(v.first!=fa&&!b[v.first])
        {
            dfs1(v.first,x);
            sz[x]+=sz[v.first];
        }
}
void dfs2(int x,int fa)
{
    int mx=totsz-sz[x];
    for(auto v:g[x])
        if(v.first!=fa&&!b[v.first])
        {
            dfs2(v.first,x);
            mx=max(mx,sz[v.first]);
        }
    if(mx<rtsz)
    {
        rtsz=mx;
        rt=x;
    }
}
void dfs3(int x,int fa)
{
    f[x]=fa;
    for(auto v:g[x])
        if(v.first!=fa&&!b[v.first])
            dfs3(v.first,x);
}
int tot;
int str[N];
namespace sam
{
    int next[2*N][2];
    int fail[2*N];
    int len[2*N];
    int last,cnt;
    int b[2*N];
    int a[2*N][2];
    int s[2*N]; 
    void init()
    {
        while(cnt)
        {
            next[cnt][0]=next[cnt][1]=0;
            a[cnt][0]=a[cnt][1]=0;
            b[cnt]=0;
            s[cnt]=0;
            cnt--;
        }
        cnt=1;
        last=1;
    }
    int insert(int p,int c)
    {
        if(next[p][c])
        {
            last=next[p][c];
            s[last]++;
            return last;
        }
//      int p=last;
        int np=++cnt;
        len[np]=len[p]+1;
        s[np]=1;
        for(;p&&!next[p][c];p=fail[p])
            next[p][c]=np;
        if(!p)
            fail[np]=1;
        else
        {
            int q=next[p][c];
            if(len[q]==len[p]+1)
                fail[np]=q;
            else
            {
                int nq=++cnt;
                len[nq]=len[p]+1;
                memcpy(next[nq],next[q],sizeof next[q]);
                fail[nq]=fail[q];
                fail[q]=fail[np]=nq;
                for(;p&&next[p][c]==q;p=fail[p])
                    next[p][c]=nq;
            }
        }
        return last=np;
    }
}
namespace pam
{
    int next[N][2];
    int trans[N][2];
    int fail[N];
    int len[N];
    int diff[N];
    int link[N];
    int top[N];
    int last;
    int cnt;
    void init()
    {
        while(cnt>=0)
        {
            next[cnt][0]=next[cnt][1]=0;
            trans[cnt][0]=trans[cnt][1]=0;
            cnt--;
        }
        cnt=1;
        str[0]=-1;
        fail[0]=1;
        fail[1]=0;
        len[0]=0;
        len[1]=-1;
        last=0;
        link[0]=0;
        diff[0]=1;
        diff[1]=0;
        top[0]=0;
        top[1]=1;
        trans[0][0]=trans[0][1]=trans[1][0]=trans[1][1]=1;
    }
    int find(int x,int c)
    {
        return str[tot-len[x]-1]==c?x:trans[x][c];
    }
    void insert(int c)
    {
        str[++tot]=c;
        last=find(last,c);
        int now=last;
        if(!next[now][c])
        {
            int cur=++cnt;
            len[cur]=len[now]+2;
            last=find(fail[last],c);
            fail[cur]=next[last][c];
            diff[cur]=len[cur]-len[fail[cur]];
            if(diff[cur]==diff[fail[cur]])
            {
                link[cur]=link[fail[cur]];
                top[cur]=top[fail[cur]];
            }
            else
            {
                link[cur]=fail[cur];
                top[cur]=cur;
            }
            if(!link[cur])
                link[cur]=cur;
            memcpy(trans[cur],trans[fail[cur]],sizeof trans[cur]);
            trans[cur][str[tot-len[fail[cur]]]]=fail[cur];
            next[now][c]=cur;
        }
        last=next[now][c];
    }
}
namespace trie
{
    int a[N][2];
    int s[N];
    int cnt;
    void clear()
    {
        while(cnt)
        {
            a[cnt][0]=a[cnt][1]=0;
            s[cnt]=0;
            cnt--;
        }
        cnt=1;
    }
}
ll s,s2;
int pos[N];
int pos2[N];
int pos3[N];
int pos4[N];
int q[N];
int len[N],id[N],top;
int head,tail;
vector<int> e[2*N];
int sq;
vector<info> h[2*N];
int orzzjt,orzzjt2;
void bfs(int x)
{
    sam::init();
//  sam::s[1]=1;
    pos[x]=1;
    head=1;
    tail=0;
    q[++tail]=x;
    trie::clear();
    pos4[x]=1;
    while(tail>=head)
    {
        int y=q[head++];
        s+=trie::s[pos4[y]];
        trie::s[pos4[y]]++;
        for(auto v:g[y])
            if(!b[v.first]&&v.first!=f[y])
            {
                pos[v.first]=sam::insert(pos[y],v.second);
                q[++tail]=v.first;
                if(trie::a[pos4[y]][v.second])
                    pos4[v.first]=trie::a[pos4[y]][v.second];
                else
                    pos4[v.first]=trie::a[pos4[y]][v.second]=++trie::cnt;
            }
    }
}
void dfs(int x,int fa)
{
    for(int y=pos[x];y!=1&&!sam::b[y];y=sam::fail[y])
    {
        sam::a[sam::fail[y]][str[tot-sam::len[sam::fail[y]]]]=y;
        sam::b[y]=1;
    }
    //這樣建出來的字尾樹不是完整的,但已經夠用了 
    
    int now=pam::last;
    pos2[x]=now;
    if(pam::len[now]==tot)
    {
        if(fa)
            s2++;
        pos3[x]=now;
    }
    else
        pos3[x]=pos3[fa];
    for(auto v:g[x])
        if(!b[v.first]&&v.first!=fa)
        {
            pam::last=now;
            pam::insert(v.second);
            dfs(v.first,x);
            tot--;
        }
}
void dfs4(int x)
{
    len[++top]=sam::len[x];
    id[top]=x;
    for(auto v:e[x])
        for(int y=pos3[v];y>1;)
            if(pam::diff[y]<=sq)
            {
                h[id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[y]-pam::diff[y])-len]].push_back(info(sam::len[x]-pam::len[y]-pam::diff[y],pam::diff[y],-1));
                h[id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[pam::link[y]])-len]].push_back(info(sam::len[x]-pam::len[pam::link[y]],pam::diff[y],1));
                //h.push_back(info(sam::len[x]-pam::len[y],id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[y])-len],1));
//              h.push_back(info(sam::len[x]-pam::len[pam::link[y]]+pam::diff[y],id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[pam::link[y]]+pam::diff[y])-len],-1));
                y=pam::fail[pam::link[y]];
                orzzjt2+=_log[top];
            }
            else
            {
                y=pam::fail[y];
            }
    if(sam::a[x][0])
        dfs4(sam::a[x][0]);
    if(sam::a[x][1])
        dfs4(sam::a[x][1]);
    top--;
}
void dfs5(int x)
{
    for(auto v:h[x])
        if(v.x>=0&&v.x!=sam::len[x])
            s+=ss[v.y][v.x%v.y]*v.z;
    orzzjt+=sq;
    for(int i=1;i<=sq;i++)
        ss[i][sam::len[x]%i]+=sam::s[x];
    ss2[sam::len[x]]+=sam::s[x];
    
    
    for(auto v:h[x])
        if(v.x>=0&&v.x==sam::len[x])
            s+=ss[v.y][v.x%v.y]*v.z;
            
            
    for(auto v:e[x])
        for(int y=pos3[v];y>1;)
            if(pam::diff[y]<=sq)
            {
                y=pam::fail[pam::link[y]];
            }
            else
            {
                s+=ss2[sam::len[x]-pam::len[y]];
                y=pam::fail[y];
            }
            
    if(sam::a[x][0])
        dfs5(sam::a[x][0]);
    if(sam::a[x][1])
        dfs5(sam::a[x][1]);
    
        
    for(int i=1;i<=sq;i++)
        ss[i][sam::len[x]%i]-=sam::s[x];
    ss2[sam::len[x]]-=sam::s[x];
}
ll calc(int x)
{
    s=0;
    s2=0;
    bfs(x);
    pam::init();
    dfs(x,0);
    for(int i=1;i<=sam::cnt;i++)
    {
        e[i].clear();
        h[i].clear();
    }
    for(int i=1;i<=tail;i++)
        e[pos[q[i]]].push_back(q[i]);
    dfs4(1);
//  for(int i=1;i<=sam::cnt;i++)
//      sort(h[i].begin(),h[i].end());
    dfs5(1);
    return s;
}
int c[N],c2[N];
int t;
vector<pii> g2;
void solve(int x)
{
    dfs1(x,0);
    totsz=sz[x];
    rtsz=0x7fffffff;
    dfs2(x,0);
    x=rt;
    dfs3(x,0);
    int t=0;
    sq=sqrt(totsz);
//  sq=0;
    ans+=calc(x);
    ans+=s2;
    for(auto v:g[x])
        if(!b[v.first])
        {
            b[v.first]=1;
            c[++t]=v.first;
            c2[t]=v.second;
        }
    g2=g[x];
    g[x].clear();
    for(int i=1;i<=t;i++)
    {
        b[c[i]]=0;
        g[x].clear();
        g[x].push_back(pii(c[i],c2[i]));
        ans-=calc(x);
        b[c[i]]=1;
    }
    g[x]=g2;
    for(int i=1;i<=t;i++)
        b[c[i]]=0;
    b[x]=1;
    for(auto v:g[x])
        if(!b[v.first])
            solve(v.first);
}
int main()
{
    open("string");
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        for(int j=1,k=0;j<=n;j<<=1,k++)
            _log[i]=k;
    int _sqrt=sqrt(n);
    for(int i=1;i<=_sqrt;i++)
    {
        ss[i]=new ll[i];
        for(int j=0;j<i;j++)
            ss[i][j]=0;
    }
    int x,y,z;
    for(int i=1;i<n;i++)
    {
        scanf("%d%d%d",&x,&y,&z);
        g[x].push_back(pii(y,z));
        g[y].push_back(pii(x,z));
    }
    solve(1);
//  assert(ans%2==0);
//  ans/=2;
    printf("%lld\n",ans);
//  printf("%d\n",orzzjt);
//  printf("%d\n",orzzjt2);
    return 0;
}