HAOI2018 染色
阿新 • • 發佈:2018-12-19
一道推柿子題目
非常鍛鍊思維能力.
題目連結
首先,顏色次數顯然不能多於\(lim=min(\frac{n}{s},m)\)個.
由於問的是恰好為\(k\)個,我們定義\(f_i\)表示出現次數為\(s\)的顏色恰好為\(k\)個的方案數.
那麼,\(ans=\sum_{i=0}^{lim}w_if_i\)
我們用二項式反演的思想,令\(g_i\)表示至少有\(i\)個顏色出現次數為\(s\)次.
\(g_i\)如何求呢?
顯然,我們首先選出\(i*s\)個格子,再選出\(i\)種顏色,然後對其進行排列,剩下的格子用剩下的顏色亂填即可.
\(g_i=C_n^{i*s}*C_m^i*\frac{(s*i)!}{(s!)^i}*(m-i)^{n-i*s}\)
這裡四個乘數對應上文所述,應該很好理解.
顯然,每個\(f_i\)會在\(g_{j,j\leq i}\)中計算\(C_j^i\)次.
那麼,\(g_i=\sum_{j=i}^{lim}C_j^if_i\)
接下來就是二項式反演環節.
\(f_i=\sum_{j=i}^{lim}(-1)^{j-i}C_j^ig_j\)
我們回憶一下\(ans\)的計算.
\(ans=\sum_{i=0}^{lim}w_if_i\)
\(=\sum_{i=0}^{lim}w_i\sum_{j=i}^{lim}(-1)^{j-i}C_j^ig_j\)
我們把\(C_j^i\)拆掉得
\(ans=\sum_{i=0}^{lim}w_i\sum_{j=i}^{lim}(-1)^{j-i}\frac{j!}{i!(j-i)!}g_j\)
把\(i!\)提出來得\(ans=\sum_{i=0}^{lim}\frac{w_i}{i!}\sum_{j=i}^{lim}\frac{(-1)^{j-i}}{(j-i)!}j!g_j\)
我們令多項式\(A=\sum_{i=0}^{lim}\frac{(-1)^i}{i!}x^i,B=\sum_{i=0}^{lim}i!g_i\)
把\(A\)反轉,然後\(A*B\)就是答案.
這個用\(NTT\)優化即可.
時間複雜度\(O(n+m*log\ m)\)
程式碼如下
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> #include<vector> #define N (600010) #define M (10000010) #define P (1004535809) #define rg register int typedef long double ld; typedef long long LL; typedef unsigned long long ull; using namespace std; inline char read(){ static const int IN_LEN=1000000; static char buf[IN_LEN],*s,*t; return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++); } template<class T> inline void read(T &x){ static bool iosig; static char c; for(iosig=false,c=read();!isdigit(c);c=read()){ if(c=='-')iosig=true; if(c==-1)return; } for(x=0;isdigit(c);c=read())x=((x+(x<<2))<<1)+(c^'0'); if(iosig)x=-x; } inline char readchar(){ static char c; for(c=read();!isalpha(c);c=read()) if(c==-1)return 0; return c; } const int OUT_LEN = 10000000; char obuf[OUT_LEN],*ooh=obuf; inline void print(char c) { if(ooh==obuf+OUT_LEN)fwrite(obuf,1,OUT_LEN,stdout),ooh=obuf; *ooh++=c; } template<class T> inline void print(T x){ static int buf[30],cnt; if(x==0)print('0'); else{ if(x<0)print('-'),x=-x; for(cnt=0;x;x/=10)buf[++cnt]=x%10+48; while(cnt)print((char)buf[cnt--]); } } inline void flush(){fwrite(obuf,1,ooh-obuf,stdout);} int n,m,S,w[N],t,Lim,len,rev[N]; LL mi[40],iv[40],ans,f[N],g[N],jc[M],inv[M]; LL ksm(LL a,int p){ LL res=1; while(p){ if(p&1)res=(res*a)%P; a=(a*a)%P,p>>=1; } return res; } void NTT(LL *a,int tp){ for(int i=0;i<Lim;i++) if(i<rev[i])swap(a[i],a[rev[i]]); for(int i=1,s=1;i<Lim;i<<=1,s++){ LL w=(tp>0)?mi[s]:iv[s]; for(int R=i<<1,j=0;j<Lim;j+=R){ LL p=1; for(int k=j;k<j+i;k++,p=p*w%P){ LL x=a[k],y=p*a[k+i]%P; a[k]=(x+y)%P,a[k+i]=(x-y+P)%P; } } } if(tp==-1){ LL dv=ksm(Lim,P-2)%P; for(int i=0;i<Lim;i++)a[i]=a[i]*dv%P; } } LL C(int n,int m){return jc[n]*inv[n-m]%P*inv[m]%P;} int main(){ read(n),read(m),read(S),jc[0]=inv[0]=inv[1]=1,t=min(n/S,m); for(int i=0;i<=m;i++)read(w[i]); for(int i=1;(1ll<<i)<=P;i++)mi[i]=ksm(3,(P-1)/(1<<i)),iv[i]=ksm(mi[i],P-2)%P; for(int i=1;i<=max(n,m);i++)jc[i]=jc[i-1]*i%P; for(int i=2;i<=max(n,m);i++)inv[i]=inv[P%i]*(P-P/i)%P; for(int i=1;i<=max(n,m);i++)inv[i]=inv[i-1]*inv[i]%P; for(int i=0;i<=t;i++)g[i]=jc[i]*C(m,i)%P*C(n,i*S)%P*jc[S*i]%P*ksm(ksm(jc[S],i),P-2)%P*ksm(m-i,n-i*S)%P; for(int i=0;i<=t;i++)f[t-i]=(i&1)?P-inv[i]:inv[i]; for(Lim=1;Lim<=t+t;Lim<<=1)len++; for(int i=0;i<Lim;i++)rev[i]=((rev[i>>1])>>1)|((i&1)<<(len-1)); NTT(f,1),NTT(g,1); for(int i=0;i<Lim;i++)f[i]=f[i]*g[i]%P; NTT(f,-1); for(int i=0;i<=t;i++)ans=(ans+1ll*w[i]*f[i+t]%P*inv[i]%P)%P; printf("%lld\n",ans); }