HDU 6338 2018HDU多校賽 第四場 Depth-First Search(組合數學+平衡樹/pbds)
大致題意:給你一個dfs序列B和一棵樹,現在讓你在這個樹上隨機選擇一個點,然後按照隨機的dfs順序走。問你最後能走出幾個dfs序列,是得該dfs序列字典序小於給定的dfs序B。
首先,我們考慮一棵樹有根樹他的dfs序有多少種。我們可以這麼考慮,對於任意點x,我都可以任意的向它的所有兒子走去,那麼就會對應 種方法。我們注意到,除了根之外,所有的點的兒子的數目等於其度數減一,那麼,我們便可以得出一棵有根樹的dfs序列為:。進一步,我們可以令 ,那麼對於不同的根,其對應樹的方案數就是 res*deg[root],也即res就是所謂的公共部分。
接著,我們來考慮這道題目。由於題目要求是字典序比給定的要小,而且是dfs序,所以我麼考慮按照它給定的順序進行dfs,逐位計算種類數。初始根的時候,我們先利用上面的公式,計算所有以編號小於B[0]的點為根的方案。然後開始dfs,當我們走到樹上的x節點,序列的第i位的時候,在x的所有的可選兒子中,查詢有多少個的編號小於B[i]。不妨設此時恰好有t個可選兒子的編號小於B[i],那麼這個點的貢獻就是
再整理一下,就是遇到一個點首先計算貢獻,然後順著往下走的同時,把這個點在它父親的可選點中刪除,對應的res也要少一個產生貢獻的自由點。由於這是dfs,所以經過某一個點之後,後面可能再次經過,然後第二次經過的時候,我們還是需要對於一個定值,看有多少可行的點比它小,這個過程如果沒有高效的方法,複雜度將會比較大。
由此,我們就需要一個,支援插入、刪除和查詢有多少個點比某一個定值小的資料結構。用一個Treap平衡樹即可,需要一個rank操作。另外,今天還新學到一個pbds庫,裡面有可以用到的紅黑樹(red-black tree),直接就可以支援插入、刪除和查詢rank的操作,可以省去大量的程式碼,值得正式比賽用。具體見程式碼:
#include<bits/stdc++.h> #define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout) #define IO ios::sync_with_stdio(0);cin.tie(0);cout.tie(0); #define mod 1000000007 #define LL long long #define N 1000010 using namespace std; LL inv[N],fac[N],ans,res; int d[N],b[N],f[N],n,m; vector<int> g[N]; int rt[N],sz; bool flag; struct Treap { #define ls T[i].ch[0] #define rs T[i].ch[1] struct treap{int ch[2],sz,val,cnt,fix;} T[N]; void up(int i) { T[i].sz=T[i].cnt+T[ls].sz+T[rs].sz; } void Rotate(int &x,bool d) { int y=T[x].ch[d]; T[x].ch[d]=T[y].ch[d^1]; T[y].ch[d^1]=x; up(x),up(x=y); } void ins(int &i,int x) { if (!i) { i=++sz;T[i].fix=rand(); T[i].sz=T[i].cnt=1; T[i].val=x;ls=rs=0; return; } T[i].sz++; if (x==T[i].val) {T[i].cnt++;return;} int d=x>T[i].val; ins(T[i].ch[d],x); if (T[T[i].ch[d]].fix<T[i].fix) Rotate(i,d); } void del(int &i,int x) { if (!i) return; if (T[i].val==x) { if (T[i].cnt>1){T[i].cnt--,T[i].sz--;return;} int d=T[ls].fix>T[rs].fix; if (ls==0||rs==0) i=ls+rs; else Rotate(i,d),del(i,x); } else T[i].sz--,del(T[i].ch[x>T[i].val],x); } int rank(int i,int x) { if (!i) return 0; if (T[i].val>x) return rank(ls,x); if (T[i].val==x) return T[ls].sz+1; return rank(rs,x)+T[ls].sz+T[i].cnt; } } treap; void init() { fac[1]=fac[0]=1; inv[1]=inv[0]=1; for(int i=2;i<N;i++) { fac[i]=fac[i-1]*i%mod; inv[i]=(mod-mod/i)*inv[mod%i]%mod; } for(int i=2;i<N;i++) inv[i]=inv[i-1]*inv[i]%mod; } void dfs(int x,int fa) { f[x]=fa; if (fa) treap.ins(rt[fa],x); for(int i=0;i<g[x].size();i++) { int y=g[x][i]; if (y==fa) continue; dfs(y,x); d[y]--; } } void dfs(int x) { if (m>n||flag||!x) return; if (d[x]!=0) { int t=treap.rank(rt[x],b[m+1]-1); //查詢有多少點小於b[m+1] ans=(ans+res*inv[d[x]]%mod*t%mod*fac[d[x]-1]%mod)%mod; //計算貢獻 if (f[b[m+1]]!=x) {flag=1;return;} m++; res=res*inv[d[x]]%mod*fac[d[x]-1]%mod; d[x]--; //改變res,同時x要少一個可選點 treap.del(rt[x],b[m]); dfs(b[m]); //把b[m]從x的可選點中刪除 } else dfs(f[x]); } int main() { init(); IO;int T;cin>>T; while(T--) { cin>>n; res=1,ans=sz=0; for(int i=1;i<=n;i++) cin>>b[i],d[i]=rt[i]=0,g[i].clear(); for(int i=1;i<n;i++) { int x,y; cin>>x>>y; g[x].push_back(y); g[y].push_back(x); d[x]++,d[y]++; } for(int i=1;i<=n;i++) res=res*fac[d[i]-1]%mod; for(int i=1;i<b[1];i++) ans=(ans+res*d[i]%mod)%mod; res=res*d[b[1]]%mod; m=1; flag=0; dfs(b[1],0); dfs(b[1]); cout<<ans<<endl; } return 0; }
然後我們還有pbds庫版本的程式碼,這個更加的簡潔,適合比賽用,但是速度可能就會慢一點點。
#include<bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#define IO ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define mod 1000000007
#define LL long long
#define N 1000010
using namespace std;
using namespace __gnu_pbds;
tree<int,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update> rbt[N];
LL inv[N],fac[N],ans,res;
int d[N],b[N],f[N],n,m;
vector<int> g[N];
bool flag;
void init()
{
fac[1]=fac[0]=1;
inv[1]=inv[0]=1;
for(int i=2;i<N;i++)
{
fac[i]=fac[i-1]*i%mod;
inv[i]=(mod-mod/i)*inv[mod%i]%mod;
}
for(int i=2;i<N;i++)
inv[i]=inv[i-1]*inv[i]%mod;
}
void dfs(int x,int fa)
{
f[x]=fa;
rbt[fa].insert(x);
for(int i=0;i<g[x].size();i++)
if (g[x][i]!=fa) dfs(g[x][i],x);
}
void dfs(int x)
{
if (m>n||flag||!x) return;
if (d[x]!=0)
{
int t=rbt[x].order_of_key(b[m+1]);
ans=(ans+res*inv[d[x]]%mod*t%mod*fac[d[x]-1]%mod)%mod;
if (f[b[m+1]]!=x) {flag=1;return;} m++;
res=res*inv[d[x]]%mod*fac[d[x]-1]%mod; d[x]--;
rbt[x].erase(b[m]); dfs(b[m]);
} else dfs(f[x]);
}
int main()
{
init();
IO;int T;cin>>T;
while(T--)
{
cin>>n; res=1,ans=0;
for(int i=1;i<=n;i++)
cin>>b[i],d[i]=0,rbt[i].clear(),g[i].clear();
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
g[x].push_back(y);
g[y].push_back(x);
d[x]++,d[y]++;
}
for(int i=1;i<=n;i++)
res=res*fac[d[i]-1]%mod;
for(int i=1;i<b[1];i++)
ans=(ans+res*(LL)d[i]%mod)%mod;
for(int i=1;i<=n;i++)
if (i!=b[1]) d[i]--;
res=res*d[b[1]]%mod; m=1; flag=0;
dfs(b[1],0); dfs(b[1]);
cout<<ans<<endl;
}
return 0;
}