1. 程式人生 > >【2019雅禮集訓】【第一類斯特林數】【NTT&多項式】permutation

【2019雅禮集訓】【第一類斯特林數】【NTT&多項式】permutation

目錄

題意

找有多少個長度為n的排列,使得從左往右數,有a個元素比之前的所有數字都大,從右往左數,有b個元素比之後的所有數字都大。
n<=2*10^5,a,b<=n

輸入格式

輸入三個整數n,a,b。

輸出格式

輸出一個整數,表示答案。

思路:

這道題是真的神啊...

首先,根據官方題解的思路,首先有一個n^2的DP:
定義dp[i][j]表示一個長度為i的排列,從前往後數一共有j個數字大於所有排在它前面的數字。
首先有轉移式:
\[dp[i][j]=dp[i-1][j-1]+(i-1)*dp[i-1][j]\]


怎麼理解這個式子呢?
首先,最後的排列一定是這樣一個形式:

中間那個是最大值(n),那麼前面第一個位置到第二個位置之間不能放任意一個數;第二個位置與第三個位置之間能夠放1~ 4之間的數;第三個與第四個之間能夠放4~9...我們能夠發現,相鄰兩個位置能夠放的數一定是小於前一個位置的。那麼我們就可以根據選中的關鍵點(如上圖中的1,4,9)將前半部分的序列分為幾部分,每一部分的代表元素為這一部分中的數字的最大值。

例如:1-423-95678,就可以看成是3個部分。代表元素分別是1,4,9。

根據定義,現在考慮的是一共有i個數字,分成了j段。考慮加入一個新的最小的數字,考慮它放在哪裡:

  1. 放在開頭,自己成為一個新的部分,就由dp[i-1][j-1]轉移而來。
    2.因為是最小的,所以可以放在之前的所有已經存在的部分中,那麼有(i-1)中方案,就由(i-1)*dp[i-1][j]轉移而來。

這樣子就有了dp的轉移式,很顯然最後的答案就是:
\[Ans=\sum_{i=1}^{n}(dp[i][a-1]*dp[n-i-1][b-1])*C_{n-1}^{i-1}\]
其實就是列舉最大值的位值i,然後從剩下的n-1中選出i-1個,再將這i-1個數字分為a-1個部分,後面的n-i-1個位置分為b-1個部分。那就有\(O(n^2)\)的演算法了。


有了上述的式子之後,接下來就比較好處理了。

仔細觀察dp的轉移式,我們會驚奇的發現它竟然和第一類斯特林數的遞推式是一樣的。也就是:
\[dp[i][j]=[^{i}_{j}]\]
第一類斯特林數是將i個數分成j個圓排列的方案數(忽略順序的前提下)。而我們可以把之前定義的"部分"每一個都看成是一個圓排列,每一個都把其中最大的值通過圓排列轉到這一部分開頭的位置,就完美的對應上了。

而在全域性看,我們可以先將n-1個數字分配當a+b-2個圓排列裡面去,然後再將這a+b-2分成左邊a-1個和右邊b-1個,就是簡單組合數了。那麼答案式就可以進一步化簡為:
\[Ans=[_{a+b-2}^{\ \ n-1}]*C_{a+b-2}^{a-1}\]
這樣子我們的主要問題就變為求前面那個斯特林數就可以了。


而如何快速求斯特林數又是另外一個問題了...
感覺這裡寫一遍的話,好像有些冗長了。就在下面貼了一個連結(還在寫:)),在裡面我會盡可能詳細的講解如何快速求解S(n,k)。

程式碼

#include<cstdio>
#include<cstring>
#include<algorithm>
#define MAXN 600000
#define MO 998244353
#define G 3
using namespace std;
int seq[MAXN+5];
int n,a,b;
int fact[MAXN+5],inv[MAXN+5];
int PowMod(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)
            ret=1LL*ret*x%MO;
        x=1LL*x*x%MO;
        y>>=1;
    }
    return ret;
}
void Prepare()
{
    fact[0]=1;
    for(int i=1;i<=MAXN;i++)
        fact[i]=1LL*fact[i-1]*i%MO;
    inv[MAXN]=PowMod(fact[MAXN],MO-2);
    for(int i=MAXN-1;i>=0;i--)
        inv[i]=1LL*inv[i+1]*(1LL*i+1LL)%MO;
}
void Reverse(int A[],int deg)
{
    for(int i=0;i<deg/2;i++)
        swap(A[i],A[deg-i-1]);
}
void NTT(int P[],int len,int oper)
{
    for(int i=1,j=0;i<len-1;i++)
    {
        for(int s=len;j^=s>>=1,~j&s;);
        if(i<j) swap(P[i],P[j]);
    }
    int unit,unit_p0;
    for(int d=0;(1<<d)<len;d++)
    {
        int m=(1<<d),m2=m*2;
        unit_p0=PowMod(G,(MO-1)/m2);
        if(oper==-1)
            unit_p0=PowMod(unit_p0,MO-2);
        for(int i=0;i<len;i+=m2)
        {
            unit=1;
            for(int j=0;j<m;j++)
            {
                int &P1=P[i+j+m],&P2=P[i+j];
                int t=1LL*unit*P1%MO;
                P1=((1LL*P2-1LL*t)%MO+MO)%MO;
                P2=(1LL*P2+1LL*t)%MO;
                unit=1LL*unit*unit_p0%MO;
            }
        }
    }
    if(oper==-1)
    {
        int inv=PowMod(len,MO-2);
        for(int i=0;i<len;i++)
            P[i]=1LL*P[i]*inv%MO;
    }
}
void Mul(int ret[],int _x[],int l1,int _y[],int l2)
{
    static int RET[MAXN+5],X[MAXN+5],Y[MAXN+5];
    int len=1;
    while(len<l1+l2)    len<<=1;
    copy(_x,_x+l1,X);copy(_y,_y+l2,Y);
    fill(X+l1,X+len,0);fill(Y+l2,Y+len,0);
    NTT(X,len,1);NTT(Y,len,1);
    for(int i=0;i<len;i++)
        RET[i]=1LL*X[i]*Y[i]%MO;
    NTT(RET,len,-1);
    copy(RET,RET+l1+l2,ret);
}
void Get(int deg,int A[],int B[])
{
    static int tmpA[MAXN+5],tmpB[MAXN+5];
    int len=deg/2;
    for(int i=0;i<len+1;i++)
        tmpA[i]=1LL*PowMod(len,i)*inv[i]%MO;
    fill(tmpA+len+1,tmpA+deg+1,0);
    for(int i=0;i<len+1;i++)
        tmpB[i]=1LL*fact[i]*A[i]%MO;
    fill(tmpB+len+1,tmpB+deg+1,0);
    Reverse(tmpA,len+1);
    Mul(tmpA,tmpA,len+1,tmpB,len+1);
    for(int i=0;i<=len;i++)
        tmpA[i]=1LL*tmpA[i+len]*inv[i]%MO;
    copy(tmpA,tmpA+len+1,B);
}
void Solve(int deg,int B[])
{
    static int tmpB[MAXN+5];
    if(deg==1)
    {
        B[1]=1;
        return;
    }
    Solve(deg/2,B);
    int hf=deg/2;
    copy(B,B+hf+1,tmpB);
    fill(tmpB+hf+1,tmpB+deg+1,0);
    Get(deg-deg%2,tmpB,tmpB+hf+1);
    Mul(B,tmpB,hf+1,tmpB+hf+1,hf+1);
    if(deg%2==1)
        for(int i=deg;i>=1;i--)
            B[i]=(1LL*B[i]*(1LL*deg-1LL)%MO+1LL*B[i-1])%MO;
}
int C(int x,int y)
{
    return 1LL*fact[x]*inv[y]%MO*inv[x-y]%MO;
}
int main()
{
    Prepare();
    scanf("%d %d %d",&n,&a,&b);
    if(n==1&&a==1&&b==1)
//注意加特判,也可以通過改變Solve底層的返回條件來相容這種情況
        printf("1\n");
    else if((n>1&&a==1&&b==1)||a==0||b==0)
        printf("0\n");
    else
    {
        Solve(n-1,seq);//快速求第一類斯特林數(nlogn)
        int part1=seq[a+b-2];
        int ans=1LL*C(a+b-2,a-1)*part1%MO;
        printf("%d\n",ans);
    }
    return 0;
}
/*
2 2 1

5 2 2

*/