1. 程式人生 > >洛谷2619/bzoj2654 Tree(凸優化+MST)

洛谷2619/bzoj2654 Tree(凸優化+MST)

bzoj的資料是真的水。。
qwq
由於本人還有很多東西不是很理解
qwq
所以這裡只寫一個正確的做法。

首先,我們會發現,對於你選擇白色邊的數目,隨著數目的上漲,斜率是單調升高的。

那麼這時候我們就可以考慮凸優化,也就是\(wqs\)二分來滿足題目中所述的正好\(k\)條邊的限制。

我們\(erf\)一個\(mid\),然後讓每一個白邊的權值都加上\(mid\),然後跑\(MST\),看最後的選的白色邊數,是否是大於等於\(k\)的,如果是,就調大\(l\),否則調小\(r\)

由於最小生成樹選擇邊的時候可能有一些玄學的錯誤,所以我們在\(sort\)的時候,對於權值相等的邊,我們優先選擇白邊。

那麼通過\(erf\),之後,我們就能得到一個上界,也就是在當前的偏移量下,我們最多的選和1相連的邊的個數。

根據\(clj\)的官方題解,這裡有兩個引理

對於一個圖,如果存在一個最小生成樹,它的白邊的數量是\(x\),那麼就稱\(x\)是最小合法白邊數。所有的最小合法白邊數形成一個區間\([l,r]\)
(因為題目保證有解,所以我們只需要找到最小的\(r\)即可)

那麼經過這個\(erf\),我們就能得到一個最小的\(r\)

那麼我們應該怎麼求整個\(MST\)的權值呢,我們會發現,對於權值相等的白邊和黑邊,由於題目保證有解,所以一定是會存在相互替代的關係的。
那我們可以按照之前的最小生成樹的策略選白邊,將其記為\(val\)

,最後輸出\(val-k*ans\)\(ans\)表示最後的\(mid\)
為什麼是\(k\)而不是具體的選的邊的數目呢?

因為題目要求正好選擇\(k\)條,而我們這裡實際上是把多餘的白邊都直接視為黑邊來做了
qwqwq
那麼這個題就能解決了
qwqwqwqwq
但是我根據CF125E那個題,有一個比較特殊的做法,但是套到這個這個題,我並不是很理解。qwq
這個坑還是之後再填吧

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<vector>
#include<map>
#include<vector>
#define mk make_pair
#define pb push_back
#define ll long long
#define int long long
using namespace std;
inline int read()
{
   int x=0,f=1;char ch=getchar();
   while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
   while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
   return x*f;
} 
const int maxn = 4e5+1e2;
struct Edge{
    int u,v,w;
    int col;
}; 
Edge e[maxn];
int n,m;
int ans;
int l=-200,r=200;
int fa[maxn];
int find(int x)
{
    if (fa[x]!=x) fa[x]=find(fa[x]);
    return fa[x];
}
int k;
bool cmp(Edge a,Edge b)
{
    if (a.w==b.w) return a.col<b.col;
    return a.w<b.w;
} 
int solve()
{
    sort(e+1,e+1+m,cmp);
    int tot=0;
    for (int i=1;i<=m;i++)
    {
        int f1 = find(e[i].u);
        int f2 = find(e[i].v);
        if (f1==f2) continue;
        //if(tot==k && e[i].col==0) continue;
        if (e[i].col==0) ++tot;
        fa[f1]=fa[f2];
    }
    return tot;
}
signed main()
{
  n=read(),m=read();k=read();
  for (int i=1;i<=m;i++)
  {
    e[i].u=read()+1;
    e[i].v=read()+1;
    e[i].w=read();
    e[i].col=read();
  }
  while(l<=r)
  {
     int mid = (l+r) >> 1;
     for (int i=1;i<=n;i++) fa[i]=i;
     for (int i=1;i<=m;i++)
     {
        if (e[i].col==0) e[i].w+=mid; 
     }
     int tmp = solve();
     if (tmp<k)
     {
        r=mid-1;
     }
     else l=mid+1,ans=mid;
     for (int i=1;i<=m;i++) 
     {
        if (e[i].col==0) e[i].w-=mid;
     }
  }
  for (int i=1;i<=n;i++) fa[i]=i;
  for (int i=1;i<=m;i++)
  if (e[i].col==0) e[i].w+=ans;
  sort(e+1,e+1+m,cmp);
  int tot=0,val=0;
  for (int i=1;i<=m;i++)
 {
        int f1 = find(e[i].u);
        int f2 = find(e[i].v);
        if (f1==f2) continue;
        if (e[i].col==0) ++tot;
        fa[f1]=fa[f2];
        val+=e[i].w;
  }
  cout<<val-k*ans;
  return 0;
}