1. 程式人生 > >【NTT】【多項式】洛谷P5158 多項式快速插值(log^2)

【NTT】【多項式】洛谷P5158 多項式快速插值(log^2)

快速插值 O ( N l o g 2 N )

O(N log^2 N) 板子。

話說這程式碼居然比 O ( N l o g 3

N ) O(N log^3 N) 更短更好寫。。。

實測本題 O ( N l o

g 3 N ) O(N log^3 N) 的時間是 O ( N l o g 2 N ) O(N log^2 N) 的5~6倍左右。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define SF scanf
#define PF printf
#define MAXN 3000010
#define MOD 998244353
using namespace std;
void Read(int &x){
	char c;
	while(c=getchar(),c!=EOF&&(c<'0'||c>'9'));
	x=c-'0';
	while(c=getchar(),c!=EOF&&c>='0'&&c<='9')
		x=x*10+c-'0';	
}
char p1[20];
void Print(int x){
	if(x==0){
		putchar('0');
	}
	int tot=0;
	while(x){
		p1[tot++]=x%10+'0';
		x/=10;
	}
	for(int i=tot-1;i>=0;i--)
		putchar(p1[i]);
	putchar(' ');
}
const int G=3;
int n,m,b[MAXN],WN[30],INVW[30];
int buf[MAXN*4];
int *bgn[MAXN],*ncnt=buf;
int fsp(int x,int y){
    int res=1;
    while(y){
        if(y&1)
            res=1ll*res*x%MOD;
        x=1ll*x*x%MOD;
        y>>=1;
    }
    return res;
}
void ntt(int A[],int N,int flag){
    for(int i=1,j=0;i<N;i++){
        for(int d=N;j^=d>>=1,~j&d;);
        if(i<j)
            swap(A[i],A[j]);	
    }
    for(int i=1,id=0;i<N;i<<=1,id++){
    	int wn;
		if(flag==0)
    		wn=WN[id];
    	else
    		wn=INVW[id];
//        int wn=fsp(G,(MOD-1)/(i<<1));
//        if(flag) wn=fsp(wn,MOD-2);
        for(int j=0;j<N;j+=i<<1){
            int w=1;
            for(int k=0;k<i;k++,w=1ll*w*wn%MOD){
                int x=A[j+k],y=1ll*w*A[i+j+k]%MOD;
                A[j+k]=(x+y)%MOD;
                A[i+j+k]=(x-y+MOD)%MOD;
            }
        }
    }
    if(flag) for(int i=0,invN=fsp(N,MOD-2);i<N;i++) A[i]=1ll*A[i]*invN%MOD;
}
void mul(int A[],int N,int B[],int M,int res[]){
    static int A1[MAXN],B1[MAXN],res1[MAXN];
    for(int i=0;i<N;i++)
        A1[i]=A[i];
    for(int i=0;i<M;i++)
        B1[i]=B[i];
    int p=1;
    while(p<N+M) p<<=1;
    ntt(A1,p,0);
    ntt(B1,p,0);
    for(int i=0;i<p;i++)
        res1[i]=1ll*A1[i]*B1[i]%MOD;
    ntt(res1,p,1);
    for(int i=0;i<N+M-1;i++)
        res[i]=res1[i];
    for(int i=0;i<p;i++)
        res1[i]=A1[i]=B1[i]=0;
}
void build_p(int id,int l,int r){
    if(l==r){
        bgn[id]=ncnt;
        bgn[id][0]=MOD-b[l];
        bgn[id][1]=1;
        ncnt+=2;
        return ;
    }
    bgn[id]=ncnt;
    ncnt+=(r-l+2);
    int mid=(l+r)>>1;
    build_p(id<<1,l,mid);
    build_p(id<<1|1,mid+1,r);
    mul(bgn[id<<1],mid-l+2,bgn[id<<1|1],r-mid+1,bgn[id]);
}
void inv(int A[],int N,int B[]){
    if(N==1){
        B[0]=fsp(A[0],MOD-2);
        return ;
    }
    inv(A,(N+1)>>1,B);
    static int tmp2[MAXN],tmp3[MAXN];
    int p=1;
    while(p<N<<1) p<<=1;
    for(int i=0;i<N;i++) tmp2[i]=A[i];
    for(int i=N;i<p;i++) tmp2[i]=0;
    ntt(tmp2,p,0);
    for(int i=(N+1)>>1;i<p;i++) B[i]=0;
    ntt(B,p,0);
    for(int i=0;i<p;i++)
        tmp3[i]=1ll*B[i]*((2ll-1ll*B[i]*tmp2[i]%MOD+MOD)%MOD)%MOD;
    ntt(tmp3,p,1);
    for(int i=0;i<N;i++) B[i]=tmp3[i];
    for(int i=0;i<p;i++) tmp2[i]=tmp3[i]=0;
}
void PolyMod(int A[],int N,int B[],int M,int res[]){
    static int ta[MAXN],tb[MAXN],tmp[MAXN];
    for(int i=0;i<N;i++) ta[i]=A[N-i-1];
    for(int i=0;i<M;i++) tb[i]=B[M-i-1];
    inv(tb,M,tmp);
    for(int i=0;i<N-M+1;i++) tb[i]=tmp[i];
    for(int i=0;i<4*M;i++) tmp[i]=0;
    mul(ta,N,tb,N-M+1,tmp);
    reverse(tmp,tmp+N-M+1);
    mul(B,M,tmp,N-M+1,tmp);
    for(int i=0;i<M-1;i++)
        res[i]=(A[i]-tmp[i]+MOD)%MOD;
    for(int i=0;i<4*(N+1);i++)
        tmp[i]=0;
}
int res[MAXN];
void Multipoint_evaluation(int A[],int N,int id,int l,int r){
	if(N>r-l+1){
		PolyMod(A,N,bgn[id],r-l+2,A);
		N=r-l+1;
	}
    if(l==r){
        int xnow=1;
        res[l]=0;
		for(int i=0;i<N;i++){
            res[l]=(res[l]+1ll*xnow*A[i]%MOD)%MOD;
            xnow=1ll*xnow*b[l]%MOD;
        }
        return ;
    }
    int mid=(l+r)>>1;
    static int tmp4[MAXN],tmp5[MAXN];
    PolyMod(A,N,bgn[id<<1],mid-l+2,tmp4);
    PolyMod(A,N,bgn[id<<1|1],r-mid+1,tmp5);
    for(int i=0;i<mid-l+1;i++) A[i]=tmp4[i];
    for(int i=0;i<r-mid;i++) A[i+mid-l+1]=tmp5[i];
    for(int i=0;i<4*N;i++) tmp4[i]=tmp5[i]=0;
    Multipoint_evaluation(A,mid-l+1,id<<1,l,mid);
    Multipoint_evaluation(A+mid-l+1,r-mid,id<<1|1,mid+1,r);
}
int Gi[MAXN],Gx[MAXN];
void Fast_interpolation(int A[],int y[],int id,int l,int r){
	if(l==r){
		A[l]=1ll*y[l]*fsp(Gi[l],MOD-2)%MOD;	
		return ;
	}
	int mid=(l+r)>>1;
	Fast_interpolation(A,y,id<<1,l,mid);
	Fast_interpolation(A,y,id<<1|1,mid+1,r);
	static int tmp6[MAXN],tmp7[MAXN];
	mul(A+l,mid-l+1,bgn[id<<1|1],r-mid+1,tmp6);
	mul(A+mid+1,r-mid,bgn[id<<1