1. 程式人生 > >Hard F2. Representative Sampling 虛樹+樹形DP

Hard F2. Representative Sampling 虛樹+樹形DP

Description 給你n個字串,讓你選出k個字串使它們的價值最大,定義一個集合的價值為兩兩最長公共字首和。

Sample Input 3 2 aba bzd abq

Sample Output 2

你可以先建出一個字典樹,建出一個虛樹。 然後節點不會超過2n個,直接樹形DP即可。 設f[i][j]為以i為子樹選了j個。 然後直接轉移,你可能需要將虛樹的節點離散化一下,為此狂WA不止。

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
const ULL P = 131;
int _min(int x, int y) {return x < y ? x : y;}
LL _max(LL x, LL y) {return x > y ? x : y;}
int read() {
	int s = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
	return s * f;
}

struct tnode {
	int son[26], s, dep;
} t[2010 * 510]; int cnt;
struct edge {
	int x, y, next;
} e[2010 * 2]; int len, last[2010 * 510];
int hh, id, uu[2010], gg[2010 * 510]; LL f[2010 * 2][2010];
int fa[2010 * 510], son[2010 * 510], dep[2010 * 510], tot[2010 * 510], top[2010 * 510];
int K, sta[2010 * 2];
char ss[510];

void ins(int x, int y) {
	e[++len].x = x, e[len].y = y;
	e[len].next = last[x], last[x] = len;
}

void ins(int len) {
	int x = 0;
	for(int i = 1; i <= len; i++) {
		int y = ss[i] - 'a';
		if(!t[x].son[y]) t[x].son[y] = ++cnt;
		x = t[x].son[y];
	} t[x].s++;
}

void pre_tree_node(int x) {
	tot[x] = 1;
	if(t[x].s) uu[++hh] = x;
	for(int i = 0; i < 26; i++) if(t[x].son[i]){
		int y = t[x].son[i];
		dep[y] = dep[x] + 1;
		fa[y] = x;
		pre_tree_node(y);
		tot[x] += tot[y];
		if(tot[son[x]] < tot[y]) son[x] = y;
	}
}

void pre_tree_edge(int x, int tp) {
	top[x] = tp;
	if(son[x]) pre_tree_edge(son[x], tp);
	for(int i = 0; i < 26; i++) if(t[x].son[i]){
		int y = t[x].son[i];
		if(y != son[x]) pre_tree_edge(y, y);
	}
}

int LCA(int x, int y) {
	int tx = top[x], ty = top[y];
	while(tx != ty) {
		if(dep[tx] > dep[ty]) swap(tx, ty), swap(x, y);
		y = fa[ty], ty = top[y];
	} if(dep[x] > dep[y]) swap(x, y);
	return x;
}

void treedp(int x) {
	gg[x] = ++id;
	tot[x] = _min(K, t[x].s);
	for(int i = 1; i <= tot[x]; i++) f[gg[x]][i] = (LL)dep[x] * i * (i - 1) / 2;
	for(int k = last[x]; k; k = e[k].next) {
		int y = e[k].y;
		treedp(y);
		for(int i = _min(tot[x], K); i >= 0; i--) {
			for(int j = _min(K - i, tot[y]); j >= 0; j--) {
				f[gg[x]][i + j] = _max(f[gg[x]][i + j], f[gg[x]][i] + f[gg[y]][j] + (LL)i * j * dep[x]);
			}
		} tot[x] += tot[y];
	}
}

int main() {
	int n = read(); K = read();
	for(int i = 1; i <= n; i++) {
		scanf("%s", ss + 1);
		ins(strlen(ss + 1));
	} pre_tree_node(0);
	pre_tree_edge(0, 0);
	int tp = 0; sta[++tp] = uu[1];
	for(int i = 2; i <= hh; i++) {
		int lca = LCA(sta[tp], uu[i]);
		if(lca == sta[tp]) {
			if(sta[tp] != uu[i]) sta[++tp] = uu[i];
			continue;
		} while(tp > 1 && dep[sta[tp - 1]] >= dep[lca]) ins(sta[tp - 1], sta[tp]), --tp;
		if(sta[tp] != lca) ins(lca, sta[tp]), sta[tp] = lca;
		sta[++tp] = uu[i];
	} int rt = sta[1];
	while(tp > 1) ins(sta[tp - 1], sta[tp]), --tp;
	memset(tot, 0, sizeof(tot));
	treedp(rt);
	printf("%I64d\n", f[gg[rt]][K]);
	return 0;
}