計蒜客 2018ICPC徐州站/gym 102012G Rikka with Intersection(組合計數 + 樹鏈剖分 + 樹狀陣列)
阿新 • • 發佈:2018-12-27
大致題意:給你一個包含n個點的樹和m條路徑。現在讓你從這m條路徑中選擇k條路,使得這k條路徑一定有至少一個公共交點,問選出這k條路徑的方案數是多少。
最樸素的想法就是,每次檢視一個點的貢獻,也就是列舉這個公共點,然後看有多少個路徑經過這個點,組合數求一下即可。但是這個錯誤也是很明顯的,因為有可能同樣一批路徑,會有超過一個的公共點,這樣的話就會重複計算。顯然,如果有多個公共點的話,計算一個即可。那麼我們到底應該如何選取這一個公共點呢?
我們考慮換一種計算方式,考慮每一條路徑的貢獻。對於一條路徑,我們可以把它的貢獻分為兩個部分,一是它與它上面k-1條邊的貢獻,另一部分是它與它下面的k-1條邊的貢獻。因為同樣的方案只需要計算一次,所以我就只計算上面一部分即可,下面一部分在考慮後面的點的貢獻的時候會考慮到。對於上面的一部分,我們考慮這條路徑上的點集與上面任意k-1合法路徑的點集的交集一定包含這條路徑的LCA,也即如果這條路徑與k-1條邊有多個公共點的話,LCA一定是其中之一。根據這條性質,我們只計算LCA即可,因為LCA已經包含了所有的方案不需要再次計算也不會重複。然後考慮完這條邊的方案數之後,還要把這條邊的上的點的貢獻加入樹中。
更具體的說,我們把所有的路徑按照LCA的深度大小排序,從深度小的開始往後處理。每一次首先計算覆蓋了當前路徑的LCA的路徑條數,對這個條數取組合數就是當前路徑的貢獻。然後把當前路徑在樹上的貢獻加入,也即路徑上的點權加一。如此維護所有的路徑並統計貢獻即可。用樹鏈剖分和樹狀陣列維護樹上路徑點權,區間修改單點查詢即可。具體見程式碼:
#include<bits/stdc++.h> #define LL unsigned long long using namespace std; const int mod = 1e9 + 7; const int N = 3e5 + 7; int id[N],top[N],son[N],size[N],fa[N],dep[N],c[N]; struct segment{int l,r,lca;} s[N]; int last[N],g[N<<1],nxt[N<<1]; int fac[N],ifac[N],inv[N]; int num,n,m,e,p; inline void addedge(int x,int y) { g[++e]=y; nxt[e]=last[x]; last[x]=e; } inline void update(int x,int y) { for(int i=x;i<N;i+=i&-i) c[i]+=y; } inline int getsum(int x) { int res=0; for(int i=x;i;i-=i&-i) res+=c[i]; return res; } void dfs1(int u,int d,int f) { son[u]=0; dep[u]=d; size[u]=1; for(int i=last[u];i;i=nxt[i]) if (g[i]!=f) { fa[g[i]]=u; dfs1(g[i],d+1,u); size[u]+=size[g[i]]; if (size[g[i]]>size[son[u]]) son[u]=g[i]; } } void dfs2(int u,int f) { top[u]=f; id[u]=++num; if (son[u]) dfs2(son[u],f); for(int i=last[u];i;i=nxt[i]) if (g[i]!=son[u]&&g[i]!=fa[u]) dfs2(g[i],g[i]); } inline void change(int u, int v) { int tp1 = top[u], tp2 = top[v]; while (tp1 != tp2) { if (dep[tp1] < dep[tp2]){swap(tp1, tp2); swap(u, v);} update(id[tp1],1); update(id[u]+1,-1); u = fa[tp1]; tp1 = top[u]; } if (dep[u] > dep[v]) swap(u, v); update(id[u],1); update(id[v]+1,-1); } inline int LCA(int u, int v) { if (u==v) return u; int tp1 = top[u], tp2 = top[v]; while (tp1 != tp2) { if (dep[tp1] < dep[tp2]){swap(tp1, tp2); swap(u, v);} u = fa[tp1]; tp1 = top[u]; } if (dep[u] > dep[v]) swap(u, v); return u; } inline bool cmp(segment a,segment b) { return dep[a.lca]<dep[b.lca]; } inline void init() { fac[0]=ifac[0]=inv[0]=1; fac[1]=ifac[1]=inv[1]=1; for(int i=2;i<N;i++) { fac[i]=fac[i-1]*(LL)i%mod; inv[i]=(mod-mod/i)*(LL)inv[mod%i]%mod; ifac[i]=ifac[i-1]*(LL)inv[i]%mod; } } inline int C(int n,int m) { if (m>n) return 0; return fac[n]*(LL)ifac[n-m]%mod*ifac[m]%mod; } int main() { int T; init(); scanf("%d",&T); while(T--) { e=num=0; LL ans=0; memset(c,0,sizeof(c)); memset(last,0,sizeof(last)); scanf("%d%d%d",&n,&m,&p); for(int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); addedge(u,v);addedge(v,u); } dfs1(1,1,1); dfs2(1,1); for(int i=1;i<=m;i++) { int u,v; scanf("%d%d",&u,&v); s[i]=segment{u,v,LCA(u,v)}; } sort(s+1,s+1+m,cmp); for(int i=1;i<=m;i++) { ans=(ans+C(getsum(id[s[i].lca]),p-1))%mod; change(s[i].l,s[i].r); } printf("%lld\n",ans); } }