1. 程式人生 > >MemSQL Start[c]UP 2.0 - Round 1 E - Three strings 廣義字尾自動機

MemSQL Start[c]UP 2.0 - Round 1 E - Three strings 廣義字尾自動機

E - Three strings

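將三個串依次插入同一個廣義字尾自動機,統計每個節點在三個串中分別出現了多少次。節點 v 對長度區間 [len(fa(v))+1, len(v)] 內每個長度的貢獻為三個出現次數的乘積;用差分陣列做區間加,最後求字首和,即可得到每個長度的答案(對 1e9+7 取模)。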

#include<bits/stdc++.h>
#define LL long long
#define fi first
#define se second
#define mk make_pair
#define PII pair<int, int>
#define PLI pair<LL, int>
#define ull unsigned long long
using namespace std;

// "Three strings" (MemSQL Start[c]UP 2.0 - Round 1, problem E).
//
// All three input strings are inserted into one suffix automaton (each string
// restarts from the root). For every state the program counts how many times
// its substrings occur in each of the three strings, multiplies the three
// counts, and adds that product to every length the state represents. The
// per-length sums are printed modulo `mod` for lengths 1..min(|s0|,|s1|,|s2|).

const int N = 5e5 + 7;  // capacity: buffers hold N chars, the automaton 2*N states
const int inf = 0x3f3f3f3f;
const long long INF = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-8;

int n;            // not used by this program
int ans[N];       // difference array over lengths, turned into the answers by a prefix sum
int len[3];       // len[i] = length of s[i]
char s[3][N];     // the three input strings

struct SuffixAutomaton {
    int cur, cnt;          // most recently created state / number of states (root is 1, 0 means "none")
    int ch[N << 1][26];    // transitions
    int id[N << 1];        // id[1..cnt]: states ordered by non-decreasing dis (filled by topo)
    int fa[N << 1];        // suffix link (parent in the link tree)
    int dis[N << 1];       // length of the longest substring of the state
    int sz[N << 1];        // set to 1 for states created directly by extend (not read here)
    int c[N];              // counting-sort buckets, indexed by LENGTH (not by state)
    int num[3][N << 1];    // num[k][v]: occurrences of state v's substrings in string k

    SuffixAutomaton() { cur = cnt = 1; }

    // Resets the automaton to a lone root so that it can be rebuilt.
    void init() {
        for (int i = 1; i <= cnt; i++) {
            std::memset(ch[i], 0, sizeof(ch[i]));
            sz[i] = dis[i] = fa[i] = id[i] = 0;
            for (int k = 0; k < 3; k++) num[k][i] = 0;
        }
        // `c` has only N entries and is indexed by length, so it must not be
        // cleared inside the per-state loop above (cnt may exceed N).
        std::fill(c, c + N, 0);
        cur = cnt = 1;
    }

    // Appends character `c` (0..25) after state `p` and returns the new state.
    // Note: when `p` already has a transition on `c` (possible because every
    // string restarts from the root) the new state gets no incoming
    // transition; solve() then never counts anything on it, so it is inert.
    int extend(int p, int c) {
        cur = ++cnt;
        dis[cur] = dis[p] + 1;
        for (; p && !ch[p][c]; p = fa[p]) ch[p][c] = cur;
        if (!p) {
            fa[cur] = 1;
        } else {
            int q = ch[p][c];
            if (dis[q] == dis[p] + 1) {
                fa[cur] = q;
            } else {
                // Split q: the clone keeps q's transitions but the shorter length.
                int nt = ++cnt;
                dis[nt] = dis[p] + 1;
                std::memcpy(ch[nt], ch[q], sizeof(ch[q]));
                fa[nt] = fa[q];
                fa[q] = fa[cur] = nt;
                for (; ch[p][c] == q; p = fa[p]) ch[p][c] = nt;
            }
        }
        sz[cur] = 1;
        return cur;
    }

    // Counting sort of the states by dis; `n` must be >= every dis value and < N.
    void topo(int n) {
        std::fill(c, c + n + 1, 0);  // start from empty buckets so repeated calls stay correct
        for (int i = 1; i <= cnt; i++) c[dis[i]]++;
        for (int i = 1; i <= n; i++) c[i] += c[i - 1];
        for (int i = cnt; i >= 1; i--) id[c[dis[i]]--] = i;
    }

    // Reads three whitespace-separated lowercase words from stdin and prints
    // the answers for lengths 1..min length, each followed by a space, then a
    // newline. Input that does not fit the fixed-size arrays or contains
    // characters outside 'a'..'z' is rejected with a message on stderr.
    void solve() {
        // Build the format at run time so the field width always tracks N.
        char fmt[32];
        std::snprintf(fmt, sizeof(fmt), "%%%ds", N - 1);

        int total = 0;
        for (int i = 0; i < 3; i++) {
            if (std::scanf(fmt, s[i]) != 1) s[i][0] = '\0';  // missing word == empty string
            len[i] = static_cast<int>(std::strlen(s[i]));
            total += len[i];
            for (int j = 0; j < len[i]; j++) {
                if (s[i][j] < 'a' || s[i][j] > 'z') {
                    std::fprintf(stderr, "error: input must consist of lowercase letters a-z\n");
                    return;
                }
            }
        }
        // Every extend() creates at most two states and ans[] is indexed up to
        // (longest length + 1), so a total of N - 2 characters is the safe limit.
        if (total > N - 2) {
            std::fprintf(stderr, "error: total input length exceeds %d\n", N - 2);
            return;
        }

        for (int i = 0; i < 3; i++) {
            int last = 1;  // each string restarts from the root
            for (int j = 0; j < len[i]; j++) last = extend(last, s[i][j] - 'a');
        }

        // Mark the state reached by every prefix of every string.
        for (int i = 0; i < 3; i++) {
            int p = 1;
            for (int j = 0; j < len[i]; j++) {
                p = ch[p][s[i][j] - 'a'];
                num[i][p]++;
            }
        }

        // Push the counts up the suffix links, longest states first.
        topo(std::max(len[0], std::max(len[1], len[2])));
        for (int i = cnt; i >= 1; i--) {
            int v = id[i];
            for (int j = 0; j < 3; j++) num[j][fa[v]] += num[j][v];
        }

        // State i covers lengths dis[fa[i]]+1 .. dis[i]; add its product to
        // that whole range through the difference array.
        for (int i = 2; i <= cnt; i++) {
            int ret = 1ll * num[0][i] * num[1][i] % mod * num[2][i] % mod;
            int mx = dis[i], mn = dis[fa[i]] + 1;
            ans[mn] = (ans[mn] + ret) % mod;
            ans[mx + 1] = (ans[mx + 1] - ret + mod) % mod;
        }

        int Len = std::min(len[0], std::min(len[1], len[2]));
        for (int i = 1; i <= Len; i++) {
            ans[i] = (ans[i] + ans[i - 1]) % mod;
            std::printf("%d ", ans[i]);
        }
        std::puts("");
    }
} sam;

int main() {
    sam.solve();
    return 0;
}