1. 程式人生 > >AC自動機--學習筆記

AC自動機--學習筆記

在學習AC自動機前,請確保你已經充分理解

我們將從這樣一個問題開始引入AC自動機 Q:給定n個模式串和1個文字串,求有多少個模式串在文字串中出現過

這個問題要怎麼解? 用N次KMP嗎,這樣顯然爆炸啊

於是閒著沒事幹腦袋又十分豐腴的科學家們有了一個奇妙的想法 在Trie上求KMP!(當然實際上只是類似KMP的nxt,定義還是有所不同的)

假設當前有5個模式串’she’, ‘he’, ‘say’, ‘shr’, ‘her’ 先建出他們的字典樹 在這裡插入圖片描述

建好字典樹後我們效仿KMP的nxt陣列 在Trie上增加fail失配指標

什麼是fail指標 假設當前結點uu所代表的串為SS,那麼uufai

lfail指標指向 最長的,能與SS的字尾匹配的Trie樹的字首結尾結點 (這都什麼 #$*&@%¥^#)

是不是有點被繞暈了,那就看這圖感性理解一下吧 在這裡插入圖片描述

比如最長的,能與串sh的字尾匹配的 Trie的字首,只有串h 以及最長的,能與串she的字尾匹配的 Trie的字首,只有串he

那麼這個fail指標要怎麼求呢 可以考慮用BFS實現 假設當前從隊首取出結點uu 對於uu的一個子節點ch[u][i]ch[u][i] 我們uu開始不斷沿著failfail指標向上跳 直到跳到一個結點vv也有表示字元ii的子節點ch[v][i]ch[v][i]

那麼ch[u][i]ch[u][i]failfail指標指向ch[v][i]ch[v][i]

特別的,如果一直跳到根都沒有符合條件的結點 那麼ch[u][i]ch[u][i]failfail指標指向根 以及注意所有第二層的結點failfail指標都指向根

void build_AC()
{
    for(int i=0;i<=25;++i)
    if(ch[0][i]) fail[ch[0][i]]=0,q.push(ch[0][i]);//第二層節點fail都指向根
    while(!q.empty
()) { int u=q.front(); q.pop(); for(int i=0;i<=25;++i) { if(!ch[u][i]) continue;//沒有這個子節點就跳過 int tt=fail[u]; while(!ch[tt][i]&&tt) tt=fail[tt];//沿著fail指標找到第一個也有同樣子節點的結點 fail[ch[u][i]]=ch[tt][i]; q.push(ch[u][i]); } } }

現在連好了fail指標,匹配就簡單了

首先用一個指標指向根 將文字串一位一位送入自動機 若當前指標存在表示文字串下一位的子節點,令指標移向該子節點 否則沿著fail指標不斷轉移,直到跳到一個存在該子節點的結點,令指標移向該子節點

指標沒跳轉完成一次,就沿著fail指標統計一次

void query(char *ss,int len)
{
    int u=0;
    for(int i=0;i<len;++i)
    {
        int x=ss[i]-'a';
        while(!ch[u][x]&&u) u=fail[u];
        u=ch[u][x];
        
        for(int t=u;t&&sum[t]!=-1;t=fail[t])
        ans+=sum[t],sum[t]=-1;
    }
}

AC自動機の應用

上述問題的果題

#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;

int read()
{
    int f=1,x=0;
    char ss=getchar();
    while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
    while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
    return x*f;
}

const int maxn=500010;
int Q,n,cnt;
char pat[maxn],txt[maxn<<1];
int ch[maxn][26],fail[maxn],sum[maxn];
queue<int> q;
int ans;

void ins(char *ss,int len)
{
    int u=0;
    for(int i=0;i<len;++i)
    {
        int x=ss[i]-'a';
        if(!ch[u][x]) ch[u][x]=++cnt;
        u=ch[u][x];
    }
    sum[u]++;
}

void build_AC()
{
    for(int i=0;i<=25;++i)
    if(ch[0][i]) fail[ch[0][i]]=0,q.push(ch[0][i]);
    while(!q.empty())
    {
        int u=q.front(); q.pop();
        for(int i=0;i<=25;++i)
        {
            if(!ch[u][i]) continue;
            int tt=fail[u];
            while(!ch[tt][i]&&tt) tt=fail[tt];
            fail[ch[u][i]]=ch[tt][i];
            q.push(ch[u][i]);
        }
    }
}

void query(char *ss,int len)
{
    int u=0;
    for(int i=0;i<len;++i)
    {
        int x=ss[i]-'a';
        while(!ch[u][x]&&u) u=fail[u];
        u=ch[u][x];
        
        for(int t=u;t&&sum[t]!=-1;t=fail[t])
        ans+=sum[t],sum[t]=-1;
    }
}

void init()
{
	ans=cnt=0;
	memset(sum,0,sizeof(sum));
	memset(ch,0,sizeof(ch));
}

int main()
{
    Q=read();
    while(Q--)
    {
    	n=read(); init();
    	for(int i=1;i<=n;++i)
    	{
        	scanf("%s",&pat);
        	ins(pat,strlen(pat));
    	}
    	scanf("%s",&txt);
    
    	build_AC(); query(txt,strlen(txt));
    	printf("%d\n",ans);
	}
    return 0;
}

Q:有N個由小寫字母組成的模式串以及一個文字串T。每個模式串可能會在文字串中出現多次。你需要找出哪些模式串在文字串T中出現的次數最多。

也是稍作修改即可的果題

#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;

int read()
{
    int f=1,x=0;
    char ss=getchar();
    while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
    while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
    return x*f;
}

const int maxn=50010;
int n;
char pt[200][100],txt[maxn*20];
int ch[maxn][26],fail[maxn],cnt;
int id[maxn],num[200];
queue<int> q;
int ans;

void ins(char *ss,int len,int k)
{
    int u=0;
    for(int i=0;i<len;++i)
    {
        int x=ss[i]-'a';
        if(!ch[u][x]) ch[u][x]=++cnt;
        u=ch[u][x];
    }
    id[u]=k;
}

void build_AC()
{
    for(int i=0;i<=25;++i)
    if(ch[0][i]) fail[ch[0][i]]=0,q.push(ch[0][i]);
    while(!q.empty())
    {
        int u=q.front(); q.pop();
        for(int i=0;i<=25;++i)
        {
            if(!ch[u][i]) continue;
            int tt=fail[u];
            while(!ch[tt][i]&&tt) tt=fail[tt];
            fail[ch[u][i]]=ch[tt][i];
            q.push(ch[u][i]);
        }
    }
}

void query(char *ss,int len)
{
    int u=0;
    for(int i=0;i<len;++i)
    {
        int x=ss[i]-'a';
        while(!ch[u][x]&&u) u=fail[u];
        u=ch[u][x];
        for(int t=u;t;t=fail[t])
        num[id[t]]++;
    }
    for(int i=1;i<=n;++i) 
    ans=max(ans,num[i]);
}

void init()
{
    ans=cnt=0;
    memset(ch,0,sizeof(ch));
    memset(id,0,sizeof(id));
    memset(num,0,sizeof(num));
}

int main()
{
    while(scanf("%d",&n)!=EOF)
    {
        if(n==0) break; init();
        for(int i=1;i<=n;++i)
        {
            scanf("%s",&pt[i]);
            ins(pt[i],strlen(pt[i]),i);
        }
        scanf("%s",&txt);
        
        build_AC(); query(txt,strlen(txt));
        printf("%d\n",ans);
        for(int i=1;i<=n;++i)
        if(num[i]==ans) printf("%s\n",pt[i]);
    }
    return 0;
}