DNA Sequence POJ - 2778 AC自動機 矩陣快速冪
阿新 • • 發佈:2018-11-19
題解
給m個長度10以內的病毒串 問長度為n的主串且不匹配任意一個病毒串的有多少個
m最大10所以節點數不超過100 利用AC自動機建圖 建立鄰接矩陣表示從節點i到節點j能轉移的字元數量 除去字元結束節點和fail指標路徑上是結束節點 通過N個鄰接矩陣相乘即可得到i到j走N步的方案數 將0到i求和即為答案
因為N過大需要用矩陣快速冪求解
AC程式碼
#include <stdio.h>
#include <iostream>
#include <queue>
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const int MOD = 1e5;
const int MAXN = 102;
const int MAXC = 10;
int nxt[MAXN][MAXC], sed[MAXN], fal[MAXN], idx;
int vis[MAXN];
char s[MAXN];
struct Matix
{
ll m[MAXN][MAXN];
Matix()
{
memset(m, 0, sizeof(m));
}
Matix operator*(const Matix &m2){
Matix t;
for ( int i = 0; i <= idx; ++i) //不能寫小於MAXN
for (int j = 0; j <= idx; ++j)
for (int k = 0; k <= idx; ++k)
t.m[i][j] = (t.m[i][j] + m[i][k] * m2.m[k][j]) % MOD;
return t;
};
}g, u;
void Insert(char *s, int n) //插入一個字串s長度為n的模式串
{
int x = 0;
for (int i = 0; i < n; i++)
{
int c = s[i];
if (!nxt[x][c])
nxt[x][c] = ++idx;
x = nxt[x][c];
}
sed[x]++;
}
void Build() //建立失配指標資訊
{
queue<int> q; //需要先給每個節點的父節點建立失配資訊 類似廣搜
for (int i = 0; i < MAXC; i++)
if (nxt[0][i]) //先將根節點連線的有效節點入隊 不能從根節點出發
q.push(nxt[0][i]); //初始每個節點的失配節點都是根
while (!q.empty())
{
int f = q.front();
q.pop();
for (int i = 0; i < MAXC; i++)
if (nxt[f][i]) //存在子節點
fal[nxt[f][i]] = nxt[fal[f]][i], q.push(nxt[f][i]); //子節點失配嘗試匹配一次父節點失配指標的子節點
else //如果不存在
nxt[f][i] = nxt[fal[f]][i]; //則直接將這個節點設定為父節點失配節點的子節點
}
}
int Match(char *s, int n) //查詢字串s能夠匹配多少模式串
{
int x = 0, res = 0; //當前節點 查詢結果
for (int i = 0; i < n; i++)
{
int c = s[i] - 'A' + 1;
x = nxt[x][c]; //轉移到當前字元 如果失配會自動到失配指標
for (int p = x; p; p = fal[p])//已經被處理過了的節點不在繼續
res += sed[p]; //將以當前節點為結尾的所有子串全部加上 -1標記為已訪問
}
return res;
}
void DFS(int x) //建立圖
{
vis[x] = 1;
for (int i = 1; i <= 4; i++)
{
int flag = 1;
for (int p = nxt[x][i]; p; p = fal[p]) //將結束節點和失配連上有結束節點的除去
if (sed[p])
{
flag = 0;
break;
}
if (flag)
{
g.m[x][nxt[x][i]]++;
if (!vis[nxt[x][i]])
DFS(nxt[x][i]);
}
}
}
int main()
{
#ifdef LOCAL
freopen("C:/input.txt", "r", stdin);
#endif
int M, N;
cin >> M >> N;
for (int i = 0; i < M; i++)
{
scanf("%s", s);
int l = strlen(s);
for (int i = 0; i < l; i++)
if (s[i] == 'A')
s[i] = 1;
else if (s[i] == 'C')
s[i] = 2;
else if (s[i] == 'G')
s[i] = 3;
else if (s[i] == 'T')
s[i] = 4;
Insert(s, l);
}
Build();
DFS(0);
for (int i = 0; i < MAXN; i++)
u.m[i][i] = 1;
while (N)
{
if (N & 1)
u = u * g;
g = g * g;
N >>= 1;
}
ll ans = 0;
for (int i = 0; i < MAXN; i++)
ans = (ans + u.m[0][i]) % MOD;
cout << ans << endl;
return 0;
}