1. 程式人生 > >【XSY2744】信仰聖光 分治FFT 多項式exp 容斥原理

【XSY2744】信仰聖光 分治FFT 多項式exp 容斥原理

getchar span 復雜度 getch con get air nom 多少

題目描述

  有一個\(n\)個元素的置換,你要選擇\(k\)個元素,問有多少種方案滿足:對於每個輪換,你都選擇了其中的一個元素。

  對\(998244353\)取模。

  \(k\leq n\leq 152501\)

題解

吐槽

  為什麽一道FFT題要把\(n\)設為\(150000\)

解法一

  先把輪換拆出來。

  直接DP。

  設\(f_{i,j}\)為前\(i\)個輪換選擇了\(j\)個元素,且每個輪換都選擇了至少一個元素的方案數。
\[ f_{i,j}=\sum_{k=1}^{a_i}f_{i-1,j-k}\binom{a_i}{k} \]
  時間復雜度為\(O(n^2)\),因為枚舉的是第\(i\)

組和前\(i-1\)組的配對,而任意兩個元素之間最多被配對一次。

  可以分治FFT做到\(O(n\log^2 n)\)

解法二

  考慮容斥。

  設\(m\)為輪換個數。

  枚舉有哪些輪換\(S\)中可能有被選中的元素,容斥系數就是\({(-1)}^{m-|S|}\)\(sum\)為這些輪換的大小總和):

  或者枚舉哪些輪換\(S\)中沒有被選中的元素,容斥系數就是\({(-1)}^{|S|}\)
\[ \begin{align} s&=\sum_{S}{(-1)}^{m-|S|}\binom{sum}{k}\s&=\sum_{S}{(-1)}^{|S|}\binom{n-sum}{k}\\end{align} \]


  現在我們要對於每一個\(i\),計算\(f_i=\sum_{S,sum=i}{(-1)}^{|S|}\)

  構造生成函數\(A_i(x)=1-x^{a_i}\),那麽\(F(x)=\prod_{i=1}^mA_i(x)\)

  直接做還是\(O(n\log^2n)\)的。我們需要一些優化。
\[ \begin{align} F(x)&=\prod_{i=1}^m1-x^{a_i}\\ln(F(x))&=\sum_{i=1}^n\ln(1-x^{a_i})\\ln(F(x))&=\sum_{i=1}^n\sum_{j=a_i}-\frac{x^{ja_i}}{j} \end{align} \]


  那麽可以在\(O(n\log n)\)內算出\(\ln(F(x))\),然後\(\exp\)一下。

  時間復雜度:\(O(n\log n)\)

  由於常數過大,所以要用下面那條式子(因為只用計算到\(x^{n-k}\))。

解法一

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<iostream>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
void open(const char *s)
{
#ifndef ONLINE_JUDGE
    char str[100];
    sprintf(str,"%s.in",s);
    freopen(str,"r",stdin);
    sprintf(str,"%s.out",s);
    freopen(str,"w",stdout);
#endif
}
int rd()
{
    int s=0,c;
    while((c=getchar())<'0'||c>'9');
    s=c-'0';
    while((c=getchar())>='0'&&c<='9')
        s=s*10+c-'0';
    return s;
}
const int p=998244353;
const int g=3;
ll fp(ll a,ll b)
{
    ll s=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)
            s=s*a%p;
    return s;
}
ll inv[200010];
ll fac[200010];
ll ifac[200010];
int a[200010];
int n,m,k;
int c[200010];
int b[200010];
ll getc(int x,int y)
{
    return fac[x]*ifac[y]%p*ifac[x-y]%p;
}
ll *f[500010];
int len[500010];
int cnt;
int a1[600010];
int a2[600010];
int rev[600010];
void ntt(int *a,int n,int t)
{
    for(int i=1;i<n;i++)
    {
        rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
        if(i>rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(int i=2;i<=n;i<<=1)
    {
        int wn=fp(g,(p-1)/i*(t==1?1:i-1));
        for(int j=0;j<n;j+=i)
        {
            int w=1;
            for(int k=j;k<j+i/2;k++)
            {
                int u=a[k];
                int v=(ll)a[k+i/2]*w%p;
                a[k]=(u+v)%p;
                a[k+i/2]=(u-v)%p;
                w=(ll)w*wn%p;
            }
        }
    }
    if(t==-1)
    {
        int inv=fp(n,p-2);
        for(int i=0;i<n;i++)
            a[i]=(ll)a[i]*inv%p;
    }
}
void solve(int &now,int l,int r)
{
    now=++cnt;
    if(l==r)
    {
        len[now]=min(a[l],k);
        f[now]=new ll[len[now]+1];
        f[now][0]=0;
        for(int i=1;i<=len[now];i++)
            f[now][i]=ifac[i]*ifac[a[l]-i]%p;
        return;
    }
    int ls,rs,mid=(l+r)>>1;
    solve(ls,l,mid);
    solve(rs,mid+1,r);
    len[now]=min(len[ls]+len[rs],k);
    f[now]=new ll[len[now]+1];
    int v=1;
    while(v<=len[ls]+len[rs])
        v<<=1;
    for(int i=0;i<v;i++)
        a1[i]=(i<=len[ls]?f[ls][i]:0);
    for(int i=0;i<v;i++)
        a2[i]=(i<=len[rs]?f[rs][i]:0);
    ntt(a1,v,1);
    ntt(a2,v,1);
    for(int i=0;i<v;i++)
        a1[i]=(ll)a1[i]*a2[i]%p;
    ntt(a1,v,-1);
    for(int i=0;i<=len[now];i++)
        f[now][i]=a1[i];
    delete [] f[ls];
    delete [] f[rs];
}
void solve()
{
//  scanf("%d%d",&n,&k);
    n=rd();
    k=rd();
    for(int i=1;i<=n;i++)
        c[i]=rd();
//      scanf("%d",&c[i]);
    if(k==n)
    {
        printf("1\n");
        return;
    }
    m=0;
    cnt=0;
    memset(b,0,sizeof b);
    memset(a,0,sizeof a);
    for(int i=1;i<=n;i++)
        if(!b[i])
        {
            m++;
            for(int j=i;!b[j];j=c[j])
            {
                b[j]=1;
                a[m]++;
            }
        }
    if(k<m)
    {
        printf("0\n");
        return;
    }
    int rt;
    solve(rt,1,m);
    ll ans=f[rt][k];
    ans=ans*fp(getc(n,k),p-2)%p;
    for(int i=1;i<=m;i++)
        ans=ans*fac[a[i]]%p;
    ans=(ans+p)%p;
    printf("%lld\n",ans);
}
int main()
{
    open("a");
    inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
    for(int i=2;i<=200000;i++)
    {
        inv[i]=-p/i*inv[p%i]%p;
        fac[i]=fac[i-1]*i%p;
        ifac[i]=ifac[i-1]*inv[i]%p;
    }
    int t;
//  scanf("%d",&t);
    t=rd();
    while(t--)
        solve();
    return 0;
}

解法二

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<iostream>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
int rd()
{
    int s=0,c;
    while((c=getchar())<'0'||c>'9');
    s=c-'0';
    while((c=getchar())>='0'&&c<='9')
        s=s*10+c-'0';
    return s;
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
    char str[100];
    sprintf(str,"%s.in",s);
    freopen(str,"r",stdin);
    sprintf(str,"%s.out",s);
    freopen(str,"w",stdout);
#endif
}
const int p=998244353;
const int g=3;
ll fp(ll a,ll b)
{
    ll s=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)
            s=s*a%p;
    return s;
}
ll inv[300010];
ll fac[300010];
ll ifac[300010];
namespace ntt
{
    int rev[600000];
    void ntt(int *a,int n,int t)
    {
        for(int i=1;i<n;i++)
        {
            rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
            if(i>rev[i])
                swap(a[i],a[rev[i]]);
        }
        for(int i=2;i<=n;i<<=1)
        {
            int wn=fp(g,(p-1)/i*(t==1?1:i-1));
            for(int j=0;j<n;j+=i)
            {
                int w=1;
                for(int k=j;k<j+i/2;k++)
                {
                    int u=a[k];
                    int v=(ll)a[k+i/2]*w%p;
                    a[k]=(u+v)%p;
                    a[k+i/2]=(u-v)%p;
                    w=(ll)w*wn%p;
                }
            }
        }
        if(t==-1)
        {
            int inv=fp(n,p-2);
            for(int i=0;i<n;i++)
                a[i]=(ll)a[i]*inv%p;
        }
    }
    void getinv(int *a,int *b,int n)
    {
        if(n==1)
        {
            b[0]=fp(a[0],p-2);
            return;
        }
        getinv(a,b,n>>1);
        static int a1[600000],a2[600000];
        for(int i=0;i<n;i++)
            a1[i]=a[i];
        for(int i=n;i<n<<1;i++)
            a1[i]=0;
        for(int i=0;i<n>>1;i++)
            a2[i]=b[i];
        for(int i=n>>1;i<n<<1;i++)
            a2[i]=0;
        ntt(a1,n<<1,1);
        ntt(a2,n<<1,1);
        for(int i=0;i<n<<1;i++)
            a1[i]=a2[i]*(2-(ll)a1[i]*a2[i]%p)%p;
        ntt(a1,n<<1,-1);
        for(int i=0;i<n;i++)
            b[i]=a1[i];
    }
    void getln(int *a,int *b,int n)
    {
        static int a1[600000],a2[600000];
        for(int i=1;i<n;i++)
            a1[i-1]=(ll)a[i]*i%p;
        a1[n-1]=0;
        getinv(a,a2,n);
        for(int i=n;i<n<<1;i++)
            a1[i]=a2[i]=0;
        ntt(a1,n<<1,1);
        ntt(a2,n<<1,1);
        for(int i=0;i<n<<1;i++)
            a1[i]=(ll)a1[i]*a2[i]%p;
        ntt(a1,n<<1,-1);
        for(int i=1;i<n;i++)
            b[i]=(ll)a1[i-1]*inv[i]%p;
        b[0]=0;
    }
    void getexp(int *a,int *b,int n)
    {
        if(n==1)
        {
            b[0]=1;
            return;
        }
        getexp(a,b,n>>1);
        static int a1[600000],a2[600000],a3[600000];
        for(int i=n>>1;i<n;i++)
            b[i]=0;
        getln(b,a3,n);
        for(int i=0;i<n>>1;i++)
        {
            a1[i]=b[i];
            a2[i]=(a[i+(n>>1)]-a3[i+(n>>1)])%p;
        }
        for(int i=n>>1;i<n;i++)
            a1[i]=a2[i]=0;
        ntt(a1,n,1);
        ntt(a2,n,1);
        for(int i=0;i<n;i++)
            a1[i]=(ll)a1[i]*a2[i]%p;
        ntt(a1,n,-1);
        for(int i=0;i<n>>1;i++)
            b[i+(n>>1)]=a1[i];
    }
}
int a[200010];
int n,m,k;
int c[200010];
int b[200010];
int cnt;
ll ans;
int d[300010];
int s[300010];
int f[300010];
ll getc(int x,int y)
{
    if(y>x||y<0)
        return 0;
    return fac[x]*ifac[y]%p*ifac[x-y]%p;
}
void dfs(int x,int y,int v)
{
    if(x>m)
    {
        ans=(ans+v*getc(y,k))%p;
        return;
    }
    dfs(x+1,y,v);
    dfs(x+1,y+a[x],-v);
}
void solve()
{
//  scanf("%d%d",&n,&k);
    n=rd();
    k=rd();
    for(int i=1;i<=n;i++)
        c[i]=rd();
//      scanf("%d",&c[i]);
    if(k==n)
    {
        printf("1\n");
        return;
    }
    m=0;
    cnt=0;
    memset(b,0,sizeof b);
    memset(a,0,sizeof a);
    for(int i=1;i<=n;i++)
        if(!b[i])
        {
            m++;
            for(int j=i;!b[j];j=c[j])
            {
                b[j]=1;
                a[m]++;
            }
        }
    if(k<m)
    {
        printf("0\n");
        return;
    }
    memset(d,0,sizeof d);
    memset(s,0,sizeof s);
    for(int i=1;i<=m;i++)
        d[a[i]]++;
    for(int i=1;i<=n;i++)
        if(d[i])
            for(int j=1;i*j<=n;j++)
                s[i*j]=(s[i*j]-inv[j]*d[i])%p;
    int l=1;
    while(l<=n-k)
        l<<=1;
    s[0]=1;
    ntt::getexp(s,f,l);
    ans=0;
    for(int i=0;i<=n-k;i++)
        ans=(ans+f[i]*getc(n-i,k))%p;
//      ans=(ans+f[i]*getc(i,k))%p;
    ans=ans*fp(getc(n,k),p-2)%p;
//  if(m&1)
//      ans=-ans;
    ans=(ans+p)%p;
    printf("%lld\n",ans);
}
int main()
{
    open("a");
    inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
    for(int i=2;i<=300000;i++)
    {
        inv[i]=-p/i*inv[p%i]%p;
        fac[i]=fac[i-1]*i%p;
        ifac[i]=ifac[i-1]*inv[i]%p;
    }
    int t;
//  scanf("%d",&t);
    t=rd();
    while(t--)
        solve();
    return 0;
}

【XSY2744】信仰聖光 分治FFT 多項式exp 容斥原理