洛谷2619/bzoj2654 Tree(凸優化+MST)
阿新 • • 發佈:2018-12-31
bzoj的資料是真的水。。
qwq
由於本人還有很多東西不是很理解
qwq
所以這裡只寫一個正確的做法。
首先,我們會發現,對於你選擇白色邊的數目,隨著數目的上漲,斜率是單調升高的。
那麼這時候我們就可以考慮凸優化,也就是\(wqs\)二分來滿足題目中所述的正好\(k\)條邊的限制。
我們\(erf\)一個\(mid\),然後讓每一個白邊的權值都加上\(mid\),然後跑\(MST\),看最後的選的白色邊數,是否是大於等於\(k\)的,如果是,就調大\(l\),否則調小\(r\)。
由於最小生成樹選擇邊的時候可能有一些玄學的錯誤,所以我們在\(sort\)的時候,對於權值相等的邊,我們優先選擇白邊。
那麼通過\(erf\),之後,我們就能得到一個上界,也就是在當前的偏移量下,我們最多的選和1相連的邊的個數。
根據\(clj\)的官方題解,這裡有兩個引理
對於一個圖,如果存在一個最小生成樹,它的白邊的數量是\(x\),那麼就稱\(x\)是最小合法白邊數。所有的最小合法白邊數形成一個區間\([l,r]\)
(因為題目保證有解,所以我們只需要找到最小的\(r\)即可)
那麼經過這個\(erf\),我們就能得到一個最小的\(r\)
那麼我們應該怎麼求整個\(MST\)的權值呢,我們會發現,對於權值相等的白邊和黑邊,由於題目保證有解,所以一定是會存在相互替代的關係的。
那我們可以按照之前的最小生成樹的策略選白邊,將其記為\(val\)
為什麼是\(k\)而不是具體的選的邊的數目呢?
因為題目要求正好選擇\(k\)條,而我們這裡實際上是把多餘的白邊都直接視為黑邊來做了
qwqwq
那麼這個題就能解決了
qwqwqwqwq
但是我根據CF125E那個題,有一個比較特殊的做法,但是套到這個這個題,我並不是很理解。qwq
這個坑還是之後再填吧
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> #include<queue> #include<vector> #include<map> #include<vector> #define mk make_pair #define pb push_back #define ll long long #define int long long using namespace std; inline int read() { int x=0,f=1;char ch=getchar(); while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();} while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();} return x*f; } const int maxn = 4e5+1e2; struct Edge{ int u,v,w; int col; }; Edge e[maxn]; int n,m; int ans; int l=-200,r=200; int fa[maxn]; int find(int x) { if (fa[x]!=x) fa[x]=find(fa[x]); return fa[x]; } int k; bool cmp(Edge a,Edge b) { if (a.w==b.w) return a.col<b.col; return a.w<b.w; } int solve() { sort(e+1,e+1+m,cmp); int tot=0; for (int i=1;i<=m;i++) { int f1 = find(e[i].u); int f2 = find(e[i].v); if (f1==f2) continue; //if(tot==k && e[i].col==0) continue; if (e[i].col==0) ++tot; fa[f1]=fa[f2]; } return tot; } signed main() { n=read(),m=read();k=read(); for (int i=1;i<=m;i++) { e[i].u=read()+1; e[i].v=read()+1; e[i].w=read(); e[i].col=read(); } while(l<=r) { int mid = (l+r) >> 1; for (int i=1;i<=n;i++) fa[i]=i; for (int i=1;i<=m;i++) { if (e[i].col==0) e[i].w+=mid; } int tmp = solve(); if (tmp<k) { r=mid-1; } else l=mid+1,ans=mid; for (int i=1;i<=m;i++) { if (e[i].col==0) e[i].w-=mid; } } for (int i=1;i<=n;i++) fa[i]=i; for (int i=1;i<=m;i++) if (e[i].col==0) e[i].w+=ans; sort(e+1,e+1+m,cmp); int tot=0,val=0; for (int i=1;i<=m;i++) { int f1 = find(e[i].u); int f2 = find(e[i].v); if (f1==f2) continue; if (e[i].col==0) ++tot; fa[f1]=fa[f2]; val+=e[i].w; } cout<<val-k*ans; return 0; }