1. 程式人生 > >[JZOJ5641] 林克卡特樹【樹形DP】【凸優化】

[JZOJ5641] 林克卡特樹【樹形DP】【凸優化】

Description

給定一棵n個節點的樹,邊有邊權(可能為負)。

你需要刪掉恰好K條邊,再連上恰好K條邊權為0的邊,並保證連完邊後這還是一棵樹,求這棵樹的最大的最長路長度。
K < n 300000 K<n\leq 300000

1 0 6
|邊權|\leq 10^6

Solution

轉化模型

刪K條邊再加K條邊,那麼對於新樹上的一條路徑,一定可以用原樹上K+1條點不相交的鏈來表示它(注意一個單點也可以看做一條鏈,因此路徑長不夠的時候可以用單點來補)

那麼問題就轉化為在原樹上選恰好K+1條鏈,使總長最大。

注意到如果我們直接設 f [ i

] [ j ] [ 0 / 1 / 2 ] f[i][j][0/1/2] 為當前做完以 i i 為根的子樹,選了j條鏈,i這個點接的鏈的情況(沒有/被子樹中不超過一條連結上/被子樹中不超過兩條連結上(即它不能再接出父親))

顯然這樣狀態數是 3 n 2 3*n^2 的,不能滿足要求

考慮優化:

感受一下,如果把K作為橫軸,x=K下的最優答案作為縱軸,那平面上就有了n個點,這n個點構成了一個凸包。

也就是說,這是單峰的,並且相鄰點連線斜率不增。

可以反證,對於選鏈的情況進行討論,再分析一下增量,發現更優的選擇一定會在更早選,具體不再贅述。

我們發現,如果沒有鏈數限制,我們可以在 O ( n ) O(n) 的時間內求出整體最優解(即凸包最高點的值,也可以求出它用了多少條鏈)

假如我們將整個凸包整體旋轉某個角度(橫座標不變),使我們需要的x=K的點成為整體最高點,那我們就可以快速算出來了。

整體旋轉某個角度,等同於用一條過原點的直線去切這個凸包
如下圖

在這裡插入圖片描述
實際上,就是對於每一個點,將縱座標減去橫座標*這條的直線的斜率

考慮這樣做在原題目中的體現,相當於每選一條鏈還需要另外支付一個代價(斜率),求最優解。

此時我們可以二分這個代價(斜率),求出最高點,看最高點的橫座標(選的鏈數)<K還是>K,來調整二分的斜率。

直到最後,我們二分出了一個恰當的斜率,使得最高點的橫座標為K,那麼輸出縱座標+K*斜率即可。

注意到對於這一題,相鄰兩個點的橫座標差一定為1,且最優值都是整數,那麼我們也只需要在整數中二分斜率。

有一種特殊情況,就是連著很多個點斜率相同,K又不在兩端,那麼此時算出的最優解橫座標不一定是=k的,但我們發現,這些點縱座標相等,加上橫座標*斜率以後,仍然能得到橫座標=k時的解,因此是沒有問題的。

這樣總複雜度就是 O ( n log M A X V ) O(n\log MAXV)

Code

#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <cstring>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fod(i,a,b) for(int i=a;i>=b;i--)
#define N 300005
#define LL long long
using namespace std;
LL f[N][2],g[N][2],h[N][2],c[2],d[2],e[2],pr[2*N],mid;
int dt[2*N],nt[2*N],fs[N],n,m,l;
void dp(int k,int fa)
{
	h[k][0]=f[k][0]=-mid,h[k][1]=f[k][1]=1;
	g[k][0]=0,g[k][1]=0;
	for(int i=fs[k];i;i=nt[i])
	{
		int p=dt[i];
		if(p!=fa) 
		{
			dp(p,k);
			fo(j,0,1) c[j]=f[k][j],d[j]=g[k][j],e[j]=h[k][j];
			if(d[0]<f[k][0]+f[p][0]+pr[i]+mid) 
			{
				d[0]=f[k][0]+f[p][0]+pr[i]+mid;
				d[1]=f[k][1]+f[p][1]-1;
			}
			if(c[0]<h[k][0]+f[p][0]+pr[i]+mid)
			{
				c[0]=h[k][0]+f[p][0]+pr[i]+mid;
				c[1]=h[k][1]+f[p][1]-1;
			}
			if(c[0]<f[k][0]+g[p][0]) c[0]=f[k][0]+g[p][0],c[1]=f[k][1]+g[p][1];
			if(d[0]<g[k][0]+g[p][0]) d[0]=g[k][0]+g[p][0],d[1]=g[k][1]+g[p][1];
			if(e[0]<h[k][0]+g[p][0]) e[0]=h[k][0]+g[p][0],e[1]=h[k][1]+g[p][1];
			fo(j,0,1) f[k][j]=c[j],g[k][j]=d[j],h[k][j]=e[j];
			if(f[k][0]<h[k][0]) f[k][0]=h[k][0],f[k][1]=h[k][1];
			if(g[k][0]<f[k][0]) g[k][0]=f[k][0],g[k][1]=f[k][1];
		}
	}
	if(f[k][0]<h[k][0]) f[k][0]=h[k][0],f[k][1]=h[k][1];
	if(g[k][0]<f[k][0]) g[k][0]=f[k][0],g[k][1]=f[k][1];
}
void link(int x,int y,int z)
{
	nt[++m]=fs[x];
	dt[fs[x]=m]=y;
	pr[m]=z;
}
int main()
{
	cin>>n>>l;
	fo(i,1,n-1) 
	{
		int x,y,z;
		scanf("%d%d%d",&x,&y,&z);
		link(x,y,z),link(y,x,z);
	}
	l++;
	LL x=-1e6,y=1e8;
	while(x<y)
	{
		mid=(x+y)/2;
		fo(i,1,n) f[i][0]=g[i][0]=h[i][0]=-1e9;
		dp(1,0);
		if(g[1][1]==l) 
		{
			printf("%lld\n",g[1][0]+mid*(LL)l);
			return 0;
		}
		if(g[1][1]<l) y=mid-1;
		else x=mid+1;
	}
	printf("%lld\n",g[1][0]+mid*(LL)l);
}