1. 程式人生 > >HDU 6338 2018HDU多校賽 第四場 Depth-First Search(組合數學+平衡樹/pbds)

HDU 6338 2018HDU多校賽 第四場 Depth-First Search(組合數學+平衡樹/pbds)

大致題意:給你一個dfs序列B和一棵樹,現在讓你在這個樹上隨機選擇一個點,然後按照隨機的dfs順序走。問你最後能走出幾個dfs序列,是得該dfs序列字典序小於給定的dfs序B。

首先,我們考慮一棵樹有根樹他的dfs序有多少種。我們可以這麼考慮,對於任意點x,我都可以任意的向它的所有兒子走去,那麼就會對應 \large son[x]! 種方法。我們注意到,除了根之外,所有的點的兒子的數目等於其度數減一,那麼,我們便可以得出一棵有根樹的dfs序列為:\large \prod (deg[x]-1)!*deg[root]。進一步,我們可以令\large res=\prod(deg[x]-1)! ,那麼對於不同的根,其對應樹的方案數就是 res*deg[root],也即res就是所謂的公共部分。

接著,我們來考慮這道題目。由於題目要求是字典序比給定的要小,而且是dfs序,所以我麼考慮按照它給定的順序進行dfs,逐位計算種類數。初始根的時候,我們先利用上面的公式,計算所有以編號小於B[0]的點為根的方案。然後開始dfs,當我們走到樹上的x節點,序列的第i位的時候,在x的所有的可選兒子中,查詢有多少個的編號小於B[i]。不妨設此時恰好有t個可選兒子的編號小於B[i],那麼這個點的貢獻就是\large \frac{t*res*(deg[x]-1)!}{deg[x]!}

,即字典序小於B[i]我可以在t箇中選擇一個,選完之後,相當於x的所有兒子中少了一個可以任意選擇位置的點,因此要少去一些方案。計算完貢獻,我們繼續順著給定的dfs序列往下走,同時要把每次按照這個dfs序經過的所有點標記為不可選,對應父親的可用節點數目要減一,這裡可以對應deg[fa]減一,因為相當於,我已經確定這個點的位置了。res的值也要改變,原因同計算貢獻的時候。

再整理一下,就是遇到一個點首先計算貢獻,然後順著往下走的同時,把這個點在它父親的可選點中刪除,對應的res也要少一個產生貢獻的自由點。由於這是dfs,所以經過某一個點之後,後面可能再次經過,然後第二次經過的時候,我們還是需要對於一個定值,看有多少可行的點比它小,這個過程如果沒有高效的方法,複雜度將會比較大。

由此,我們就需要一個,支援插入、刪除和查詢有多少個點比某一個定值小的資料結構。用一個Treap平衡樹即可,需要一個rank操作。另外,今天還新學到一個pbds庫,裡面有可以用到的紅黑樹(red-black tree),直接就可以支援插入、刪除和查詢rank的操作,可以省去大量的程式碼,值得正式比賽用。具體見程式碼:

#include<bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define IO ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define mod 1000000007
#define LL long long
#define N 1000010
using namespace std;

LL inv[N],fac[N],ans,res;
int d[N],b[N],f[N],n,m;
vector<int> g[N];
int rt[N],sz;
bool flag;

struct Treap
{
    #define ls T[i].ch[0]
    #define rs T[i].ch[1]

    struct treap{int ch[2],sz,val,cnt,fix;} T[N];

    void up(int i)
    {
        T[i].sz=T[i].cnt+T[ls].sz+T[rs].sz;
    }

    void Rotate(int &x,bool d)
    {
        int y=T[x].ch[d];
        T[x].ch[d]=T[y].ch[d^1];
        T[y].ch[d^1]=x; up(x),up(x=y);
    }

    void ins(int &i,int x)
    {
        if (!i)
        {
            i=++sz;T[i].fix=rand();
            T[i].sz=T[i].cnt=1;
            T[i].val=x;ls=rs=0;
            return;
        }
        T[i].sz++;
        if (x==T[i].val) {T[i].cnt++;return;}
        int d=x>T[i].val; ins(T[i].ch[d],x);
        if (T[T[i].ch[d]].fix<T[i].fix) Rotate(i,d);
    }

    void del(int &i,int x)
    {
        if (!i) return;
        if (T[i].val==x)
        {
            if (T[i].cnt>1){T[i].cnt--,T[i].sz--;return;}
            int d=T[ls].fix>T[rs].fix;
            if (ls==0||rs==0) i=ls+rs;
            else Rotate(i,d),del(i,x);
        } else T[i].sz--,del(T[i].ch[x>T[i].val],x);
    }

    int rank(int i,int x)
    {
        if (!i) return 0;
        if (T[i].val>x) return rank(ls,x);
        if (T[i].val==x) return T[ls].sz+1;
        return rank(rs,x)+T[ls].sz+T[i].cnt;
    }

} treap;


void init()
{
    fac[1]=fac[0]=1;
    inv[1]=inv[0]=1;
    for(int i=2;i<N;i++)
    {
        fac[i]=fac[i-1]*i%mod;
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    }
    for(int i=2;i<N;i++)
        inv[i]=inv[i-1]*inv[i]%mod;
}

void dfs(int x,int fa)
{
    f[x]=fa;
    if (fa) treap.ins(rt[fa],x);
    for(int i=0;i<g[x].size();i++)
    {
        int y=g[x][i];
        if (y==fa) continue;
        dfs(y,x); d[y]--;
    }
}

void dfs(int x)
{
    if (m>n||flag||!x) return;
    if (d[x]!=0)
    {
        int t=treap.rank(rt[x],b[m+1]-1);                    //查詢有多少點小於b[m+1]
        ans=(ans+res*inv[d[x]]%mod*t%mod*fac[d[x]-1]%mod)%mod;        //計算貢獻
        if (f[b[m+1]]!=x) {flag=1;return;} m++;
        res=res*inv[d[x]]%mod*fac[d[x]-1]%mod; d[x]--;       //改變res,同時x要少一個可選點
        treap.del(rt[x],b[m]); dfs(b[m]);                //把b[m]從x的可選點中刪除
    } else dfs(f[x]);
}

int main()
{
    init();
    IO;int T;cin>>T;
    while(T--)
    {
        cin>>n; res=1,ans=sz=0;
        for(int i=1;i<=n;i++)
            cin>>b[i],d[i]=rt[i]=0,g[i].clear();
        for(int i=1;i<n;i++)
        {
            int x,y;
            cin>>x>>y;
            g[x].push_back(y);
            g[y].push_back(x);
            d[x]++,d[y]++;
        }
        for(int i=1;i<=n;i++)
            res=res*fac[d[i]-1]%mod;
        for(int i=1;i<b[1];i++)
            ans=(ans+res*d[i]%mod)%mod;
        res=res*d[b[1]]%mod; m=1; flag=0;
        dfs(b[1],0); dfs(b[1]);
        cout<<ans<<endl;
    }
    return 0;
}

然後我們還有pbds庫版本的程式碼,這個更加的簡潔,適合比賽用,但是速度可能就會慢一點點。

#include<bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#define IO ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define mod 1000000007
#define LL long long
#define N 1000010
using namespace std;
using namespace __gnu_pbds;
tree<int,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update> rbt[N];
LL inv[N],fac[N],ans,res;
int d[N],b[N],f[N],n,m;
vector<int> g[N];

bool flag;

void init()
{
    fac[1]=fac[0]=1;
    inv[1]=inv[0]=1;
    for(int i=2;i<N;i++)
    {
        fac[i]=fac[i-1]*i%mod;
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    }
    for(int i=2;i<N;i++)
        inv[i]=inv[i-1]*inv[i]%mod;
}

void dfs(int x,int fa)
{
    f[x]=fa;
    rbt[fa].insert(x);
    for(int i=0;i<g[x].size();i++)
        if (g[x][i]!=fa) dfs(g[x][i],x);
}

void dfs(int x)
{
    if (m>n||flag||!x) return;
    if (d[x]!=0)
    {
        int t=rbt[x].order_of_key(b[m+1]);
        ans=(ans+res*inv[d[x]]%mod*t%mod*fac[d[x]-1]%mod)%mod;
        if (f[b[m+1]]!=x) {flag=1;return;} m++;
        res=res*inv[d[x]]%mod*fac[d[x]-1]%mod; d[x]--;
        rbt[x].erase(b[m]); dfs(b[m]);
    } else dfs(f[x]);
}

int main()
{
    init();
    IO;int T;cin>>T;
    while(T--)
    {
        cin>>n; res=1,ans=0;
        for(int i=1;i<=n;i++)
            cin>>b[i],d[i]=0,rbt[i].clear(),g[i].clear();
        for(int i=1;i<n;i++)
        {
            int x,y;
            cin>>x>>y;
            g[x].push_back(y);
            g[y].push_back(x);
            d[x]++,d[y]++;
        }
        for(int i=1;i<=n;i++)
            res=res*fac[d[i]-1]%mod;
        for(int i=1;i<b[1];i++)
            ans=(ans+res*(LL)d[i]%mod)%mod;
        for(int i=1;i<=n;i++)
            if (i!=b[1]) d[i]--;
        res=res*d[b[1]]%mod; m=1; flag=0;
        dfs(b[1],0); dfs(b[1]);
        cout<<ans<<endl;
    }
    return 0;
}