1. 程式人生 > >bzoj 5093 圖的價值 —— 第二類斯特林數+NTT

bzoj 5093 圖的價值 —— 第二類斯特林數+NTT

題目:https://www.lydsy.com/JudgeOnline/problem.php?id=5093

每個點都是等價的,從點的貢獻來看,得到式子:

\( ans = n * \sum\limits_{d=0}^{n-1} d^{k} * 2^{C_{n-1}^{2}} * C_{n-1}^{d} \)

使用 \( n^{k} = \sum\limits_{i=0}^{k} S(k,i) * i! *C_{n}^{i} \)

得到 \( ans = n * \sum\limits_{d=0}^{n-1} 2^{C_{n-1}^{2}} * C_{n-1}^{d} * \sum\limits_{j=0}^{k} S(k,j) * j! * C_{d}^{j} \)

此時不要把組合數拆成階乘!雖然拆成階乘可以消去 \( d! \),但如果不消去,放在一起可以得到新的組合意義;

\( ans = n * 2^{C_{n-1}^{2}} * \sum\limits_{j=0}^{k} S(k,j) * j! * \sum\limits_{d=0}^{n-1} C_{n-1}^{d} * C_{d}^{j} \)

而 \( \sum\limits_{d=0}^{n-1} C_{n-1}^{d} * C_{d}^{j} \) 表示從 \( n-1 \) 個人裡選 \( d \) 個人,再從 \( d \) 個人裡選 \( j \) 個人;

其實就是從 \( n-1 \) 個人裡選 \( j \) 個人,剩下的人隨便選,即 \( C_{n-1}^{j} * 2^{n-1-j} \)

所以 \( ans = n * 2^{C_{n-1}^{2}} * \sum\limits_{j=0}^{k} S(k,j) * j! * C_{n-1}^{j} * 2^{n-1-j} \)

而通過 \( S(n,m) = \frac{1}{m!} \sum\limits_{k=0}^{m} C_{m}^{k} * (m-k)^{n} * (-1)^{k} \) (列舉 \( k \) 個空組,最後除去 \( m \) 組的排列)

即 \( S(n,m) = \sum\limits_{k=0}^{m} \frac{(m-k)^{n}}{(m-k)!} * \frac{(-1)^{k}}{k!} \)

可以用NTT求出一行的第二類斯特林數,也就是求出 \( S(k,i) \)

然後把 \( C_{n-1}^{j} \) 拆開約分,上下都只有 \( k \) 級別,預處理即可;

還是要注意次數是對 \( mod-1 \) 取模。

程式碼如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const xn=2e5+5,xm=(1<<19),mod=998244353;
int n,m,lim,a[xm],b[xm],rev[xm],jc[xn],jcn[xn],jd[xn];
int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
ll pw(ll a,ll b)
{
  ll ret=1; a=a%mod; b=b%(mod-1);
  for(;b;b>>=1,a=(a*a)%mod)if(b&1)ret=(ret*a)%mod;
  return ret;
}
void init()
{
  jc[0]=1;
  for(int i=1;i<=m;i++)jc[i]=(ll)jc[i-1]*i%mod;
  jcn[m]=pw(jc[m],mod-2);
  for(int i=m-1;i>=0;i--)jcn[i]=(ll)jcn[i+1]*(i+1)%mod;
  jd[0]=1;
  for(int j=1;j<=m;j++)jd[j]=(ll)jd[j-1]*(n-j)%mod;
}
void ntt(int *a,int tp)
{
  for(int i=0;i<lim;i++)
    if(i<rev[i])swap(a[i],a[rev[i]]);
  for(int mid=1;mid<lim;mid<<=1)
    {
      int len=(mid<<1),wn=pw(3,tp==1?(mod-1)/len:(mod-1)-(mod-1)/len);
      for(int j=0;j<lim;j+=len)
    for(int k=0,w=1;k<mid;k++,w=(ll)w*wn%mod)
      {
        int x=a[j+k],y=(ll)w*a[j+mid+k]%mod;
        a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y);
      }
    }
  if(tp==1)return; int inv=pw(lim,mod-2);
  for(int i=0;i<lim;i++)a[i]=(ll)a[i]*inv%mod;
}
int main()
{
  scanf("%d%d",&n,&m); init();
  lim=1; int l=0;
  while(lim<=m+m)lim<<=1,l++;
  for(int i=0;i<lim;i++)rev[i]=((rev[i>>1]>>1)|((i&1)<<(l-1)));
  for(int i=0;i<=m;i++)a[i]=(ll)pw(i,m)*jcn[i]%mod;
  for(int i=0;i<=m;i++)b[i]=upt((i&1?-1:1)*jcn[i]);
  ntt(a,1); ntt(b,1);
  for(int i=0;i<lim;i++)a[i]=(ll)a[i]*b[i]%mod;
  ntt(a,-1);
  int ans=0;
  for(int j=0;j<=m;j++)
    ans=(ans+(ll)a[j]*jc[j]%mod*jd[j]%mod*jcn[j]%mod*pw(2,n-1-j))%mod;
  printf("%lld\n",(ll)n*pw(2,((ll)(n-1)*(n-2)/2))%mod*ans%mod);
  return 0;
}