1. 程式人生 > >【凸優化】【長鏈剖分】【2019冬令營模擬1.8】tree

【凸優化】【長鏈剖分】【2019冬令營模擬1.8】tree

PROMBLEM

給你一棵樹,你需要在樹上選擇恰好 m條點不相交的、長度至少為 k的路徑,使得路徑所覆蓋的點權和儘可能大。求最大點權和。
資料保證有解。

SOLUTION

  • 這是一道綜合的題目,考察凸優化、長鏈剖分、樹形DP、以及關於陣列空間的優化

  • 首先引進凸優化

    • 凸優化就是關於答案可以表示成一個凸函式 f(y),x是題目給出的引數,並且 這個函式的斜率成下降的趨勢(反過來也可以)
      在這裡插入圖片描述
    • 假設我們已知的函式的最大值是f(m’),而我們要求的是f(m),發現m在m’的後面。
    • 我們這個時候可以給這個凸函式加上一個正比例函式,具體就是f '(x)=f (x)+kx
    • 那麼我們會發現對應的x越大,加上的這個值就會越大,相對來說m就會比m’的增量更大,那麼當這個k到一定範圍時,我們的凸函式的最大值就會在m上,通過f '(x)就可以間接求出f(x)了。
    • 從影象上理解,就是將這個凸函式向上(逆時針方向)旋轉
    • 另外,如果這個最大值的m’在m的右邊的話,我們就要將函式向下旋轉,這樣才能保證m能旋到最高處。
    • 對於下凸的函式可以類比解決。
    • 實現我們可以二分這個k,判斷m’在m的哪邊再旋轉。
  • 對於這題來說,我們可以感性地理解,選的鏈的個數越多我們每次的增量就越小,也就是斜率遞減,那麼就成了一個凸函式。那麼我們就可以用這個性質把m給省去

    。於是就變成求最大值了。

  • PS.所有的答案與點權都是整數,所以每次的增量也是整數,二分就不會存在精度問題

  • 我們顯然可以運用樹形DP,設f[x][i]表示以x為根的子樹中,有一條以x為top的長度為i的鏈,符合題目的路徑和這條鏈的總點權和。

    • 但是這個DP是N的三次方的。怎麼優化它?
      1 列舉一個f[x][i],發現可以與它合併的f[y][j]是一段連續的區間,可以用字尾max進行優化。變成N方複雜度
      2 運用長鏈剖分可以優化成O(n)
      • 長鏈剖分是什麼?
      • 顧名思義,樹鏈剖分是以子樹大小做重兒子,長鏈剖分就是以子樹的最大的深度的兒子作重兒子。因為所有的狀態都以深度為關鍵字,我們只需要在每一條“輕邊”的地方轉移整一條“長鏈”,所有的結點只在一個長鏈裡面,所以只會轉移n次
      • 這樣做樹形DP是O(n)的,空間複雜度還是N方的。但是一個點的重兒子遍歷完之後它的資訊被全部轉移到這個點上,這個重兒子的空間就可以釋放掉了。所以實際上可用空間還是O(n)的。
      • 實現上,我用的是DFS序去模擬f陣列,也就是預留陣列空間

總時間複雜度O(nlog(S)),空間複雜度O(n)。
S表示點權絕對值之和。

附上程式碼(第一次打,非常非常醜

#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#define maxn 150005
#define ll long long 
using namespace std;

int n,m,K,i,j,k,x,y,Mxk;
ll v[maxn],l,r,mid,Mx;
ll f[maxn],f0[maxn],g[maxn],g0[maxn],tag[maxn],tag0[maxn];
int em,e[maxn*2],nx[maxn*2],ls[maxn];
int dep[maxn],pson[maxn],mxdep[maxn],tot,dfn[maxn],fa[maxn];

void insert(int x,int y){
	em++; e[em]=y; nx[em]=ls[x]; ls[x]=em; 
	em++; e[em]=x; nx[em]=ls[y]; ls[y]=em;
}

void dfs(int x,int p){
	dep[x]=dep[p]+1; pson[x]=0; mxdep[x]=dep[x];
	for(int i=ls[x];i;i=nx[i]) if (e[i]!=p) {
		dfs(e[i],x);
		if (!pson[x]||mxdep[e[i]]>mxdep[x]) 
			mxdep[x]=mxdep[e[i]],pson[x]=e[i];
	}
}

void dfs2(int x,int p){
	dfn[x]=++tot; 
	if (pson[x]) fa[pson[x]]=fa[x],dfs2(pson[x],x);
	for(int i=ls[x];i;i=nx[i]) if (e[i]!=p&&e[i]!=pson[x])
		fa[e[i]]=e[i],dfs2(e[i],x);
}

void dfs3(int x,int p,ll D){
	ll s=0; int c=0; ll tmp; int tmpc;
	for(int i=ls[x];i;i=nx[i]) if (e[i]!=p)
		dfs3(e[i],x,D),s+=g[e[i]],c+=g0[e[i]];
	
	g[x]=s,g0[x]=c;
	tag[fa[x]]+=v[x]+s-g[pson[x]],tag0[fa[x]]+=c-g0[pson[x]];
	f[dfn[x]]=v[x]+s-tag[fa[x]];
	f0[dfn[x]]=c-tag0[fa[x]];
	if (mxdep[x]>dep[x]&&(f[dfn[x]+1]>f[dfn[x]]||
		f[dfn[x]+1]==f[dfn[x]]&&f0[dfn[x]+1]>f[dfn[x]])) 
			f[dfn[x]]=f[dfn[x]+1],f0[dfn[x]]=f0[dfn[x]+1];
			
	if (mxdep[x]-dep[x]+1>=K){
		tmp=f[dfn[x]+K-1]+tag[fa[x]]+D;
		tmpc=f0[dfn[x]+K-1]+tag0[fa[x]]+1;
		if (tmp>g[x]||tmp==g[x]&&tmpc>=g0[x]) 
			g[x]=tmp,g0[x]=tmpc;
	} 
	
	for(int i=ls[x];i;i=nx[i]) if (e[i]!=p&&e[i]!=pson[x]){
		int y=e[i];
		for(j=0;j<=mxdep[y]-dep[y];j++) if (K-j-1<=mxdep[x]-dep[x]+1) {
			tmp=-g[y]+f[dfn[y]+j]+tag[fa[y]]+f[dfn[x]+max(1,K-j-1)-1]+tag[fa[x]]+D;
			tmpc=-g0[y]+f0[dfn[y]+j]+f0[dfn[x]+max(1,K-j-1)-1]+tag0[fa[y]]+tag0[fa[x]]+1;
			if (tmp>g[x]||tmp==g[x]&&tmpc>g0[x]) g[x]=tmp,g0[x]=tmpc;
		}
		for(j=mxdep[y]-dep[y];j>=0;j--) {
			tmp=s-g[y]+f[dfn[y]+j]+tag[fa[y]]+v[x];
			tmpc=c-g0[y]+f0[dfn[y]+j]+tag0[fa[y]];
			if (tmp>f[dfn[x]+j+1]+tag[fa[x]]||
				tmp==f[dfn[x]+j+1]+tag[fa[x]]&&tmpc>f0[dfn[x]+j+1]+tag0[fa[x]]) 
					f[dfn[x]+j+1]=tmp-tag[fa[x]],f0[dfn[x]+j+1]=tmpc-tag0[fa[x]];
			if ((f[dfn[x]+j]<f[dfn[x]+j+1]||f[dfn[x]+j]==f[dfn[x]+j+1]&&f0[dfn[x]+j]<f0[dfn[x]+j+1]) 
			&&j+1<=mxdep[x]-dep[x])
				f[dfn[x]+j]=f[dfn[x]+j+1],f0[dfn[x]+j]=f0[dfn[x]+j+1];
		}
	}
	
	if (g[x]>Mx) Mx=g[x],Mxk=g0[x];
}

int solve(ll D){
	memset(f,0,sizeof(f));
	memset(g,0,sizeof(g));
	memset(f0,0,sizeof(f0));
	memset(g0,0,sizeof(g0));
	memset(tag,0,sizeof(tag));
	memset(tag0,0,sizeof(tag0));
	Mx=0,Mxk=0;
	dfs3(1,0,D);
	return Mxk;
}

int main(){
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
	scanf("%d%d%d",&n,&m,&K);
	ll s=0;
	for(i=1;i<=n;i++) scanf("%lld",&v[i]),s+=abs(v[i]);
	for(i=1;i<n;i++){
		scanf("%d%d",&x,&y);
		insert(x,y);
	}
	dfs(1,0);
	tot=0,fa[1]=1,dfs2(1,0);
	k=solve(0);
	if (k>m) l=-s,r=0; else
	if (k<m) l=0,r=s; else {
		printf("%lld",Mx);
		return 0;
	}
	while (l<r-1){
		mid=(l+r)/2;
		k=solve(mid);
		if (k>m) r=mid; else
		if (k<m) l=mid; else {
 			printf("%lld\n",Mx-mid*m);
			return 0;
		}
	}
	solve(r);
	printf("%lld",Mx-r*m);
}