BZOJ-4556 找相同字元(廣義字尾自動機)
阿新 • • 發佈:2018-12-11
給定兩個字串,求出在兩個字串中各取出一個子串使得這兩個子串相同的方案數。兩個方案不同當且僅當這兩 個子串中有一個位置不同。 Input 兩行,兩個字串s1,s2,長度分別為n1,n2。1 <=n1, n2<= 200000,字串中只有小寫字母
Output 輸出一個整數表示答案
說實話,在我聽到“廣義字尾自動機”的時候很懵逼。 其實到現在我都沒有徹底搞懂這是什麼玩意兒。
看了一下部落格,我的理解就是一種重用,兩個串同時建進自動機的時候,讓兩個串中的公共部分共用節點,以達到節省空間的目的。
但是…我空間要開多大呢?極端情況下,兩個串毫無重複,無法共用,也就是說我還是得開兩倍的空間?
又看了網上大佬的程式碼,竟然跟普通的字尾自動機一模一樣?只是多了一個last = root的操作?
那這樣建有什麼好處呢?
Sam的學習太難受了…資料太少了。
照著自己的想法打了一遍,ac了這題。
在add函式開頭加上:
if(next[last][c] && step[last] + 1 == step[next[last][c]]) {
last = next[last][c];
return;
}
達到“已有狀態重用”的目的,然後再記錄每個串裡當前狀態的出現次數,自底向上更新。
最後用字符集大小 * a串中的取法 * b串中的取法,累加得到答案。
這真的是廣義字尾自動機嗎…我看不出跟開兩個sam有什麼區別?
ac程式碼:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 100005;
char s[maxn];
struct Sam {
int next[maxn << 2][26];
int link[maxn << 2], step[maxn << 2];
int a[maxn << 1], b[maxn << 2];
ll cov[2][maxn << 2];
int sz, last, root, len;
void init() {
sz = last = root = 1;
}
void add(int c, int m) {
if(next[last][c] && step[last] + 1 == step[next[last][c]]) {
last = next[last][c];
return;
}
int p = last;
int np = ++sz;
last = np;
step[np] = step[p] + 1;
while(!next[p][c] && p) {
next[p][c] = np;
p = link[p];
}
if(p == 0) {
link[np] = root;
} else {
int q = next[p][c];
if(step[p] + 1 == step[q]) {
link[np] = q;
} else {
int clone = ++sz;
memcpy(next[clone], next[q], sizeof(next[q]));
step[clone] = step[p] + 1;
link[clone] = link[q];
link[q] = link[np] = clone;
while(next[p][c] == q && p) {
next[p][c] = clone;
p = link[p];
}
}
}
}
void build() {
init();
scanf("%s", s);
len = strlen(s);
for(int i = 0; i < len; i++) {
add(s[i] - 'a', 0);
cov[0][last]++;
}
scanf("%s", s);
len = strlen(s);
last = root;
for(int i = 0; i < len; i++) {
add(s[i] - 'a', 1);
cov[1][last]++;
}
for(int i = 1; i <= sz; i++) {
a[step[i]]++;
}
for(int i = 1; i <= len; i++) {
a[i] += a[i - 1];
}
for(int i = 1; i <= sz; i++) {
b[a[step[i]]--] = i;
}
for(int i = sz; i > 1; i--) {
int e = b[i];
cov[0][link[e]] += cov[0][e];
cov[1][link[e]] += cov[1][e];
}
}
void solve() {
build();
ll ans = 0;
for(int i = 1; i <= sz; i++) {
ans += (step[i] - step[link[i]]) * cov[0][i] * cov[1][i];
}
printf("%lld\n", ans);
}
} sam;
int main() {
sam.solve();
return 0;
}