1. 程式人生 > >LibreOJ #2478.「九省聯考 2018」林克卡特樹 樹形dp+帶權二分

LibreOJ #2478.「九省聯考 2018」林克卡特樹 樹形dp+帶權二分

題意

給出一棵n個節點的樹和k,邊有邊權,要求先從樹中選k條邊,然後把這k條邊刪掉,再加入k條邊權為0的邊,滿足操作完後的圖仍然是一棵樹。問新樹的帶權直徑最大是多少。
n,k3105

分析

不難發現我們要求的就是在樹中選出k+1條不相交的鏈使得其權值和最大。
當k比較小的時候,我們可以樹形dp,設f[i,j,0/1/2]表示以i為根的樹中選了j條鏈,節點i的度數為0/1/2時的最大權值。
這樣做的複雜度是O(nk)的。
當k變大之後,我們就可以用帶權二分來做。
具體來說就是二分一個權值mid,然後給每條路徑的權值加上mid,再去除k的限制後進行dp。若選出的鏈的數量不小於k+1,則把mid減小,否則把mid增大。
當某一刻選出的鏈數量恰好為k+1,則當前權值-mid*(k+1)就是答案了。

程式碼

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>

typedef long long LL;

const int N=300005;
const LL inf=(LL)1e12;

int n,k,cnt,last[N];
LL mid;
struct data
{
    LL x,y;
    bool operator > (const data &d) const
{return x>d.x||x==d.x&&y>d.y;} bool operator < (const data &d) const {return x<d.x||x==d.x&&y<d.y;} data operator + (const data &d) const {return (data){x+d.x,y+d.y};} }f[N][3],tmp[3]; struct edge{int to,next,w;}e[N*2]; int read() { int x=0,f=1;char
ch=getchar(); while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } void addedge(int u,int v,int w) { e[++cnt].to=v;e[cnt].w=w;e[cnt].next=last[u];last[u]=cnt; e[++cnt].to=u;e[cnt].w=w;e[cnt].next=last[v];last[v]=cnt; } void dp(int x,int fa) { f[x][0]=(data){0,0};f[x][1]=(data){-inf,0};f[x][2]=(data){mid,1}; for (int i=last[x];i;i=e[i].next) { int to=e[i].to; if (to==fa) continue; dp(to,x); data u;tmp[0]=tmp[1]=tmp[2]=(data){-inf,0}; u=f[x][0]+f[to][0];tmp[0]=std::max(tmp[0],u); u.x+=(LL)e[i].w+mid;u.y++;tmp[1]=std::max(tmp[1],u); u=f[x][0]+f[to][1];tmp[0]=std::max(tmp[0],u); u.x+=(LL)e[i].w;tmp[1]=std::max(tmp[1],u); u=f[x][0]+f[to][2];tmp[0]=std::max(tmp[0],u); u=f[x][1]+f[to][0];tmp[1]=std::max(tmp[1],u); u.x+=(LL)e[i].w;tmp[2]=std::max(tmp[1],u); u=f[x][1]+f[to][1];tmp[1]=std::max(tmp[1],u); u.x+=e[i].w-mid;u.y--;tmp[2]=std::max(tmp[2],u); u=f[x][1]+f[to][2];tmp[1]=std::max(tmp[1],u); u=f[x][2]+f[to][0];tmp[2]=std::max(tmp[2],u); u=f[x][2]+f[to][1];tmp[2]=std::max(tmp[2],u); u=f[x][2]+f[to][2];tmp[2]=std::max(tmp[2],u); f[x][0]=tmp[0];f[x][1]=tmp[1];f[x][2]=tmp[2]; } } int main() { n=read();k=read(); for (int i=1;i<n;i++) { int x=read(),y=read(),z=read(); addedge(x,y,z); } LL l=-inf,r=inf,ans; while (l<=r) { mid=(l+r)/2; dp(1,0); data u=std::max(f[1][0],std::max(f[1][1],f[1][2])); if (u.y>=k+1) ans=u.x,r=mid-1; else l=mid+1; } printf("%lld",ans-(LL)(r+1)*(k+1)); return 0; }