1. 程式人生 > >Codeforces 623E Transforming Sequence 【FFT】

Codeforces 623E Transforming Sequence 【FFT】

題目描述及題解

題解就到這位大佬的部落格上看吧。。說得很清楚。。
然而菜爆了的我還是調了3個小時。。。所以來說說實現細節。。

  1. nlogn預處理單位根 ω \omega ,遞推乘的話精度會炸
  2. mod是1e9+7,常規的FFT會炸long long,需要拆係數:
    多項式A和B相乘,把 a
    i , b i a_i,b_i
    拆成 k M
    + p kM+p
    的形式( M = m o
    d M=\sqrt{mod}
    )
    再分別做 k [ a i ] k [ b i ] k [ a i ] p [ b i ] p [ a i ] k [ b i ] p [ a i ] p [ b i ] k[a_i]與k[b_i],k[a_i]與p[b_i],p[a_i]與k[b_i],p[a_i]與p[b_i] 的卷積,然後合併答案
  3. static 是靜態陣列,函式裡的static a=1在第二次呼叫時不會重新賦值。。。
#include<cstdio>
#include<cmath>
#include<algorithm>
#define LL long long
#define maxn 120005
using namespace std;
const double Pi = acos(-1);
const int mod = 1e9+7, M1 = 31623, M2 = 14122;
struct complex
{
	double r,i;
	complex(double _r=0,double _i=0):r(_r),i(_i){}
	complex operator + (const complex &t)const{return complex(r+t.r,i+t.i);}
	complex operator - (const complex &t)const{return complex(r-t.r,i-t.i);}
	complex operator * (const complex &t)const{return complex(r*t.r-i*t.i,r*t.i+i*t.r);}
	complex conj(){return complex(r,-i);}
}w[16][maxn/2];
void change(complex *a,int len)
{
	for(int i=1,j=len/2,k;i<len-1;i++)
	{
		if(i<j) swap(a[i],a[j]);
		for(k=len/2;j>=k;j-=k,k>>=1);
		j+=k;
	}
}
int m,len,f[maxn],g[maxn],p2[maxn];
LL n,fac[maxn],inv[maxn];
inline void fft(complex *a,int flg)
{
	change(a,len);
	for(int i=2,o=0;i<=len;i<<=1,o++)
		for(int j=0;j<len;j+=i)
			for(int k=j;k<j+i/2;k++)
			{
				complex u=a[k],v=(flg==1?w[o][k-j]:w[o][k-j].conj())*a[k+i/2];
				a[k]=u+v,a[k+i/2]=u-v;
			}
	if(flg==-1) for(int i=0;i<len;i++) a[i].r/=len;
}
void calc(int *A,int *B,int *ret)
{
	static complex sta[2][2][maxn];
	for(int i=0;i<len;i++)
		if(i<=m)
		{
			sta[0][0][i]=A[i]/M1,sta[0][1][i]=A[i]%M1;
			sta[1][0][i]=B[i]/M1,sta[1][1][i]=B[i]%M1;
		}
		else sta[0][0][i]=sta[0][1][i]=sta[1][0][i]=sta[1][1][i]=0;
	fft(sta[0][0],1),fft(sta[0][1],1),fft(sta[1][0],1),fft(sta[1][1],1);
	static complex rt[3][maxn];
	for(int i=0;i<len;i++)
	{
		rt[0][i]=sta[0][1][i]*sta[1][1][i];
		rt[1][i]=sta[0][0][i]*sta[1][1][i]+sta[0][1][i]*sta[1][0][i];
		rt[2][i]=sta[0][0][i]*sta[1][0][i];
	}
	fft(rt[0],-1),fft(rt[1],-1),fft(rt[2],-1);
	for(int i=0;i<len;i++) ret[i]=(llround(rt[0][i].r)%mod+llround(rt[1][i].r)%mod*M1%mod+llround(rt[2][i].r)*M2%mod)%mod;
}
void solve(int *A,int *B,int cnt,int *ret)
{
	static int tmp[2][maxn];
	int sp2=1;
	for(int i=0;i<len;i++,sp2=1ll*sp2*p2[cnt]%mod) 
		tmp[0][i]=1ll*A[i]*sp2%mod,tmp[1][i]=B[i];
	calc(tmp[0],tmp[1],ret);
}
void FAC_INV(int N)
{
	fac[0]=fac[1]=inv[0]=inv[1]=p2[0]=1,p2[1]=2;
	for(int i=2;i<=N;i++) fac[i]=fac[i-1]*i%mod,inv[i]=(mod-mod/i)*inv[mod%i]%mod,p2[i]=p2[i-1]*2%mod;
	for(int i=2;i<=N;i++) inv[i]=inv[i]*inv[i-1]%mod;
}
int main()
{
	scanf("%lld%d",&n,&m);
	if(n>m) return puts("0"),0;
	FAC_INV(m);
	len=1;while(len<2*m+1) len<<=1;
	for(int i=2,k=0;i<=len;i<<=1,k++)
		for(int j=0;j<i/2;j++) w[k][j]=complex(cos(2*Pi*j/i),sin(2*Pi*j/i));
	for(int i=1;i<=m;i++) g[i]=inv[i];
	f[0]=1;
	int cnt=1,ans=0;
	for(;n;n>>=1,solve(g,g,cnt,g),cnt<<=1) if(n&1) solve(f,g,cnt,f);
	for(int i=1;i<=m;i++) ans=(ans+f[i]*fac[m]%mod*inv[m-i]%mod)%mod;
	printf("%d",ans);
}