1. 程式人生 > >計蒜客 2018ICPC徐州站/gym 102012G Rikka with Intersection(組合計數 + 樹鏈剖分 + 樹狀陣列)

計蒜客 2018ICPC徐州站/gym 102012G Rikka with Intersection(組合計數 + 樹鏈剖分 + 樹狀陣列)

 

 

大致題意:給你一個包含n個點的樹和m條路徑。現在讓你從這m條路徑中選擇k條路,使得這k條路徑一定有至少一個公共交點,問選出這k條路徑的方案數是多少。

最樸素的想法就是,每次檢視一個點的貢獻,也就是列舉這個公共點,然後看有多少個路徑經過這個點,組合數求一下即可。但是這個錯誤也是很明顯的,因為有可能同樣一批路徑,會有超過一個的公共點,這樣的話就會重複計算。顯然,如果有多個公共點的話,計算一個即可。那麼我們到底應該如何選取這一個公共點呢?

我們考慮換一種計算方式,考慮每一條路徑的貢獻。對於一條路徑,我們可以把它的貢獻分為兩個部分,一是它與它上面k-1條邊的貢獻,另一部分是它與它下面的k-1條邊的貢獻。因為同樣的方案只需要計算一次,所以我就只計算上面一部分即可,下面一部分在考慮後面的點的貢獻的時候會考慮到。對於上面的一部分,我們考慮這條路徑上的點集與上面任意k-1合法路徑的點集的交集一定包含這條路徑的LCA,也即如果這條路徑與k-1條邊有多個公共點的話,LCA一定是其中之一。根據這條性質,我們只計算LCA即可,因為LCA已經包含了所有的方案不需要再次計算也不會重複。然後考慮完這條邊的方案數之後,還要把這條邊的上的點的貢獻加入樹中。

更具體的說,我們把所有的路徑按照LCA的深度大小排序,從深度小的開始往後處理。每一次首先計算覆蓋了當前路徑的LCA的路徑條數,對這個條數取組合數就是當前路徑的貢獻。然後把當前路徑在樹上的貢獻加入,也即路徑上的點權加一。如此維護所有的路徑並統計貢獻即可。用樹鏈剖分和樹狀陣列維護樹上路徑點權,區間修改單點查詢即可。具體見程式碼:

#include<bits/stdc++.h>
#define LL unsigned long long
using namespace std;

const int mod = 1e9 + 7;
const int N = 3e5 + 7;

int id[N],top[N],son[N],size[N],fa[N],dep[N],c[N];
struct segment{int l,r,lca;} s[N];
int last[N],g[N<<1],nxt[N<<1];
int fac[N],ifac[N],inv[N];
int num,n,m,e,p;

inline void addedge(int x,int y)
{
    g[++e]=y; nxt[e]=last[x]; last[x]=e;
}

inline void update(int x,int y)
{
    for(int i=x;i<N;i+=i&-i)
        c[i]+=y;
}

inline int getsum(int x)
{
    int res=0;
    for(int i=x;i;i-=i&-i)
        res+=c[i];
    return res;
}

void dfs1(int u,int d,int f)
{
    son[u]=0; dep[u]=d; size[u]=1;
    for(int i=last[u];i;i=nxt[i])
        if (g[i]!=f)
        {
            fa[g[i]]=u;
            dfs1(g[i],d+1,u);
            size[u]+=size[g[i]];
            if (size[g[i]]>size[son[u]]) son[u]=g[i];
        }
}

void dfs2(int u,int f)
{
    top[u]=f; id[u]=++num;
    if (son[u]) dfs2(son[u],f);
    for(int i=last[u];i;i=nxt[i])
        if (g[i]!=son[u]&&g[i]!=fa[u]) dfs2(g[i],g[i]);
}

inline void change(int u, int v)
{
    int tp1 = top[u], tp2 = top[v];
    while (tp1 != tp2)
    {
        if (dep[tp1] < dep[tp2]){swap(tp1, tp2); swap(u, v);}
        update(id[tp1],1); update(id[u]+1,-1);
        u = fa[tp1]; tp1 = top[u];
    }
    if (dep[u] > dep[v]) swap(u, v);
    update(id[u],1); update(id[v]+1,-1);
}

inline int LCA(int u, int v)
{
    if (u==v) return u;
    int tp1 = top[u], tp2 = top[v];
    while (tp1 != tp2)
    {
        if (dep[tp1] < dep[tp2]){swap(tp1, tp2); swap(u, v);}
        u = fa[tp1]; tp1 = top[u];
    }
    if (dep[u] > dep[v]) swap(u, v);
    return u;
}

inline bool cmp(segment a,segment b)
{
    return dep[a.lca]<dep[b.lca];
}

inline void init()
{
    fac[0]=ifac[0]=inv[0]=1;
    fac[1]=ifac[1]=inv[1]=1;
    for(int i=2;i<N;i++)
    {
        fac[i]=fac[i-1]*(LL)i%mod;
        inv[i]=(mod-mod/i)*(LL)inv[mod%i]%mod;
        ifac[i]=ifac[i-1]*(LL)inv[i]%mod;
    }

}

inline int C(int n,int m)
{
    if (m>n) return 0;
    return fac[n]*(LL)ifac[n-m]%mod*ifac[m]%mod;
}

int main()
{
    int T;
    init();
    scanf("%d",&T);
    while(T--)
    {
        e=num=0; LL ans=0;
        memset(c,0,sizeof(c));
        memset(last,0,sizeof(last));
        scanf("%d%d%d",&n,&m,&p);
        for(int i=1;i<n;i++)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            addedge(u,v);addedge(v,u);
        }
        dfs1(1,1,1); dfs2(1,1);
        for(int i=1;i<=m;i++)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            s[i]=segment{u,v,LCA(u,v)};
        }
        sort(s+1,s+1+m,cmp);
        for(int i=1;i<=m;i++)
        {
            ans=(ans+C(getsum(id[s[i].lca]),p-1))%mod;
            change(s[i].l,s[i].r);
        }

        printf("%lld\n",ans);
    }
}