2018-2019 ACM-ICPC, Asia Xuzhou Regional Contest G. Rikka with Intersections of Paths(樹上差分+LCA+容斥)
阿新 • • 發佈:2018-12-19
題目連結:http://codeforces.com/gym/102012/problem/G
題目大意:有一棵n個結點的樹,現在給出m條樹上的路徑。現在要從這m條路徑中選出k條路徑,使得這k條路徑至少有一個公共交點,問你總共有多少種方案數。
題目思路:(今年徐州現場的銀牌題,我們隊肝到最後也沒能肝出來,錯失了銀牌。。。QAQ,當時忘了一個重要的性質,導致正思路都錯了。還是太菜了)
感慨一下,繼續分析題目。
解決這個題,需要用到一個重要的性質:一個樹上任意兩條路徑如果有交點的話,那麼這些交點中肯定有一個為兩條路徑中的一條路徑兩端點的lca。
有了這個性質的話,我們可以對通過列舉路徑的交點來求答案。
對於每個節點,我們假設通過這個節點的路徑有M條,以這個點為LCA且通過這個節點的路徑有N條。
那麼在這個節點對答案的貢獻為:。這個式子計算出來的是,從通過這個節點的路徑中選出k條路徑,且至少有一條路徑的LCA為這個節點的方案數,這樣選的話就不會出現重複選的情況了,因為至少有一條路徑以該節點為LCA,在以其他點為交點的時候就不會重複計算了。
而通過某個結點的路徑數我們可以通過樹上差分計算,假設通過u這個節點的路徑為sum[u]。那麼在更新路徑[u,v]的時候,我們就令sum[u]++,sum[v]++,sum[lca(u,v)]--,sum[fa[lca(u,v)]]--。接著再用dfs一遍即可。
具體實現看程式碼:
#include <bits/stdc++.h> #define fi first #define se second #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 #define pb push_back #define MP make_pair #define lowbit(x) x&-x #define clr(a) memset(a,0,sizeof(a)) #define _INF(a) memset(a,0x3f,sizeof(a)) #define FIN freopen("in.txt","r",stdin) #define IOS ios::sync_with_stdio(false) #define fuck(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<int, int>pii; typedef pair<ll, ll>pll; const int MX = 3e5 + 5; const int mod = 1e9 + 7; int n, m, k; struct edge {int v, w, nxt;} E[MX << 1]; int head[MX], tot; int dep[MX], ST[MX][20]; void add_edge(int u, int v) { E[tot].v = v; E[tot].nxt = head[u]; head[u] = tot++; } void dfs(int u, int d, int fa) { dep[u] = d; ST[u][0] = fa; for (int i = head[u]; ~i; i = E[i].nxt) { int v = E[i].v; if (v == fa) continue; dfs(v, d + 1, u); } } void pre_solve() { dfs(1, 0, 1); for (int i = 1; i < 20; i++) { for (int j = 1; j <= n; j++) { ST[j][i] = ST[ST[j][i - 1]][i - 1]; } } } int LCA(int u, int v) { while (dep[u] != dep[v]) { if (dep[u] < dep[v]) swap(u, v); int d = dep[u] - dep[v]; for (int i = 0; i < 20; i++) if (d >> i & 1)u = ST[u][i]; } if (u == v) return u; for (int i = 19; i >= 0; i--) { if (ST[u][i] != ST[v][i]) { u = ST[u][i]; v = ST[v][i]; } } return ST[u][0]; } int sum[MX], lca_sum[MX]; void solve(int u, int fa) { for (int i = head[u]; ~i; i = E[i].nxt) { int v = E[i].v; if (v == fa) continue; solve(v, u); sum[u] += sum[v]; } } ll f[MX], inv[MX]; ll qpow(ll a, ll b) { ll res = 1; while (b) { if (b & 1) res = (res * a) % mod; a = (a * a) % mod; b >>= 1; } return res; } void init() { f[1] = 1; for (int i = 2; i < MX; i++) f[i] = (f[i - 1] * i) % mod; inv[MX - 1] = qpow(f[MX - 1], mod - 2); for (int i = MX - 2; i >= 1; i--) inv[i] = (inv[i + 1] * (i + 1)) % mod; } ll C(int n, int m) { if (n < 0 || m < 0 || m > n) return 0; if (m == 0 || m == n) return 1; return f[n] * inv[n - m] % mod * inv[m] % mod; } int main() { // FIN; init(); int T; cin >> T; while (T--) { scanf("%d%d%d", &n, &m, &k); for (int i = 1; i <= n; i++) head[i] = -1; tot = 0; for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); add_edge(u, v); add_edge(v, u); } pre_solve(); for (int i = 1; i <= m; i++) { int u, v; scanf("%d%d", &u, &v); int lca = LCA(u, v); lca_sum[lca]++; sum[u]++; sum[v]++; sum[lca]--; if (lca != 1) sum[ST[lca][0]]--; } solve(1, 0); ll ans = 0; for (int i = 1; i <= n; i++) ans = (ans % mod + ((C(sum[i], k) - C(sum[i] - lca_sum[i], k)) % mod + mod) % mod) % mod; printf("%lld\n", ans); } return 0; }