1. 程式人生 > >Wannafly挑戰賽10 F.小H和遺蹟

Wannafly挑戰賽10 F.小H和遺蹟

時間限制:C/C++ 1秒,其他語言2秒
空間限制:C/C++ 262144K,其他語言524288K
64bit IO Format: %lld

題目描述

    小H在一片遺蹟中發現了一種古老而神祕的文字,這種文字也由26種字母組成,小H用小寫字母來代替它們。遺蹟裡總共有N句話,由於年代久遠,每句話至少有一處無法辨識,用#表示,缺失的可能是任意長度(也包括0)的任意字串。
    小H發現這些話非常相似,現在小H想知道,有多少對句子可能是相同的
    注意:(x,x)這樣的句子對不計入答案,(x,y),(y,x)視為同一個句子對(詳見樣例)

輸入描述:

第1行,一個整數N
第2~N+1行,每行一個字串表示一句話
2≤N≤500,000,所有字串的總長不超過1,000,000

輸出描述:

一行,一個整數,表示你給出的答案

思路:

    首先貼上一段題解:

    定義兩個字串A,B“前相似”,當且僅當A是B的字首或B是A的字首

    定義兩個字串A,B“後相似”,當且僅當A是B的字首或B是A的字尾

    先證明結論:兩個字串可能相同,當且僅當A和B在第一個#之前的部分前相似,並且A和B在最後一個#之後的部分後相似

    證明:結論的必要性顯然,下證充分性

    1.首先可以將A和B在第一個#之前的部分、A和B在最後一個#之後的部分都去掉而不影響結果,原因是:不妨設A在第一個#之前的部分是B在第一個#之前的部分的字首,則他們都可以變成B#的形式,後面的部分同理

    2.然後若A,B中有一個是單獨的#,則顯然成立

    3.否則A開頭必然是#X#的形式(X是任意字串),同理B開頭是#Y#,則可以將#X和#Y去掉而不影響結果,因為他們都可以變成XY#的形式,這就可以轉第2步遞迴構造,得證

    有了這個結論,我們可以把每個串在第一個#之前的部分和在最後一個#之後的部分分別插入到兩棵Trie樹中,兩個串可能相同當且僅當他們對應的節點在兩棵樹上都是祖孫關係

    我們可以通過在第一棵Trie樹上進行一次DFS求解,每到一個節點就先統計它在另一棵樹上對應的點的祖先和子樹上的所有點的和並計入答案,再將它在另一棵樹上對應的點加一,整個過程可以用DFS序+兩個樹狀陣列實現,時間複雜度O((Σlen)*log(Σlen))

       這種前後綴插入兩個Trie的思路算是一個很常見的套路了。之前大多是用一個字串前後綴在相應Trie中dfs序作為一對座標,然後計算一個矩形區域內點的個數。本題貢獻的計算與之前有所不同,下面解釋下題解中最後計算貢獻的部分:

       我們先對字尾Trie做一個dfs序,然後對字首Triedfs+回溯 計算貢獻。對於字首Trie上一個節點u,我們先找到其代表的字串在後綴Trie上對應的節點v,這樣貢獻可以分為兩部分計算,一部分是u的祖先集與v的後代集的交,另一部分是u的祖先集與v的祖先集的交,這樣統計可以做到不重不漏。我們用兩個樹狀陣列分別計算這兩部分貢獻,每次對第一個樹狀陣列L[v]的位置+1,這樣當我們的遍歷到u的時候,區間L[v]~R[v]的和就是第一部分貢獻。第二個樹狀陣列L[v]位置+1,R[v]+1位置-1,由於dfs序中v的祖先節點在v前,而v的兄弟節點在其區間中+1和-1相抵消,所以sum(L[v])即為第二部分貢獻。

AC程式碼:

#include <iostream>
#include <algorithm>
#include <vector>
#include <cstring>
#include <map>
#include <set>
#include <queue>
#include <cstdio>
#include <cmath>
using namespace std;
typedef long long LL;
const int N=1e6+5;

struct Trie {
	int cnt,ch[N][26];

	void init() {
		cnt=0;
		//memset(ch,0,sizeof(ch));
	}

	int insert(char *s,int op) {
		int l=strlen(s);
		int u=0,t=(op==1)?0:l-1;
		for(int i=t;;i+=op) {
			if(s[i]=='#') break;
			int c=s[i]-'a';
			if(!ch[u][c]) {
				ch[u][c]=++cnt;
			}
			u=ch[u][c];
		}
		return u;
	}
}pre,suf;

struct BIT {
	int m,bit[N];

	void init(int k) {
		m=k;
		//memset(bit,0,sizeof(bit));
	}

	void add(int i,int x) {
		for(;i<=m;i+=i&-i) {
			bit[i]+=x;
		}
	}

	int sum(int i) {
		int s=0;
		for(;i>0;i-=i&-i) {
			s+=bit[i];
		}
		return s;
	}
}b1,b2;

int n,u,v;
int pos,L[N],R[N];
char s[N];
vector<int> vc[N];
LL ans;

void dfs(int u) {
	L[u]=++pos;
	for(int i=0;i<26;i++) {
		int v=suf.ch[u][i];
		if(v) dfs(v);
	}
	R[u]=pos;
}

void cal(int u) {
	int v;
	for(int i=0;i<vc[u].size();i++) {
		v=vc[u][i];
		ans+=b1.sum(R[v])-b1.sum(L[v])+b2.sum(L[v]);
		b1.add(L[v],1);
		b2.add(L[v],1);
		b2.add(R[v]+1,-1);
	}
	for(int i=0;i<26;i++) {
		v=pre.ch[u][i];
		if(v) cal(v);
	}
	for(int i=0;i<vc[u].size();i++) {
		v=vc[u][i];
		b1.add(L[v],-1);
		b2.add(L[v],-1);
		b2.add(R[v]+1,1);
	}
}

void solve() {
	ans=0;pos=0;
	dfs(0);
	b1.init(pos);
	b2.init(pos);
	cal(0);
}

int main() {
	scanf("%d",&n);
	pre.init();
	suf.init();
	for(int i=1;i<=n;i++) {
		scanf("%s",s);
		u=pre.insert(s,1);
		v=suf.insert(s,-1);
		vc[u].push_back(v);
	}
	solve();
	printf("%lld\n",ans);
}