1. 程式人生 > >[BZOJ4555][TJOI2016&HEOI2016]求和(分治FFT)

[BZOJ4555][TJOI2016&HEOI2016]求和(分治FFT)

turn dft size pos algorithm 實現 return printf body

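解法一:設 $g(i)=\sum_{j=0}^{i} S(i,j)\,2^j\,j!$,答案即 $\sum_{i=0}^{n} g(i)$。枚舉其中一個集合的大小,容易得到遞推式 $g(i)=2\sum_{j=1}^{i}\binom{i}{j}\,g(i-j)$,$g(0)=1$。令 $f(i)=g(i)/i!$,則 $f(i)=2\sum_{j=1}^{i}\frac{f(i-j)}{j!}$,這是一個卷積形式的遞推,可以用CDQ分治+FFT(NTT)求解,最後答案為 $\sum_{i=0}^{n} f(i)\,i!$。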

代碼用時:1h 比較順利,沒有低級錯誤。

實現比較簡單,11348ms

#include<cstdio>
#include<algorithm>
// Inclusive counting loop: i takes the values l, l+1, ..., r.
#define rep(i,l,r) for (int i=l; i<=r; ++i)
typedef long long ll;
using namespace std;

// Capacity of every global array: 2^18 entries plus 100 spare slots.
const int N=(1<<18)+100;
// Modulus applied by every "%P" reduction in this program.
const int P=998244353;
// Base that DFT raises to a power (via ksm) to obtain its per-stage multiplier.
const int g=3;

// Upper index read from the input.
int n;
// Index permutation applied by DFT before its butterfly passes.
int rev[N];

// inv[i]: modular inverse of i, filled by main.
ll inv[N];
// fac[i]: i! reduced mod P.
ll fac[N];
// facinv[i]: product of inv[1..i], i.e. the inverse of fac[i].
ll facinv[N];
// f[i]: values accumulated by cdq; main sums f[i]*fac[i].
ll f[N];
// Scratch buffers that cdq transforms and multiplies pointwise.
ll a[N];
ll b[N];

// Modular power by repeated squaring.
// Returns a^b with every intermediate product reduced by %P.
// The loop runs until b becomes zero, so b is expected to be non-negative.
ll ksm(ll a,ll b){
   ll result=1;
   while (b){
      if (b & 1) result=result*a%P;
      a=a*a%P;
      b>>=1;
   }
   return result;
}

// In-place number-theoretic transform of a[0..n-1] modulo P.
//   a : coefficient buffer, overwritten with the transformed values
//   n : transform length; the caller must have filled rev[0..n-1] for this length
//   f : 1 selects the forward transform, -1 the inverse (which also scales by 1/n)
// The butterfly temporaries are plain int, so the entries of a must already
// lie in [0, P) when this is called.
void DFT(ll a[],int n,int f){
   // Reorder the input according to the precomputed permutation.
   for (int i=0; i<n; i++)
      if (i<rev[i]) swap(a[i],a[rev[i]]);

   for (int half=1; half<n; half<<=1){
      const int span=half<<1;
      // Forward uses g^((P-1)/span); inverse uses the complementary exponent,
      // which yields the modular inverse of that multiplier.
      // NOTE(review): presumably g is a primitive root of P - confirm.
      int exponent=(P-1)/span;
      if (f!=1) exponent=(P-1)-exponent;
      int wn=ksm(g,exponent);

      for (int base=0; base<n; base+=span){
         int w=1;
         for (int k=0; k<half; k++){
            int x=a[base+k];
            int y=1ll*w*a[half+base+k]%P;
            a[base+k]=(x+y)%P;
            a[half+base+k]=(x-y+P)%P;   // +P keeps the difference non-negative
            w=1ll*w*wn%P;
         }
      }
   }

   if (f==-1){
      // The inverse transform divides by n: multiply by n^(P-2).
      int lenInv=ksm(n,P-2);
      for (int i=0; i<n; i++) a[i]=1ll*a[i]*lenInv%P;
   }
}
void cdq(int l,int r){ if (l==r) return; int mid=(l+r)>>1,lim=r-l+1,n=1,L=0; cdq(l,mid); while (n<lim) n<<=1,L++; rep(i,0,n-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1)); rep(i,0,n-1) a[i]=b[i]=0; rep(i,l,mid) a[i-l]=f[i]; rep(i,0,r-l) b[i]=facinv[i]; DFT(a,n,1); DFT(b,n,1); rep(i,0,n-1) a[i]=a[i]*b[i]%P; DFT(a,n,-1); rep(i,mid+1,r) f[i]=(f[i]+2*a[i-l])%P; cdq(mid+1,r); } int main(){ freopen("bzoj4555.in","r",stdin); freopen("bzoj4555.out","w",stdout); scanf("%d",&n); inv[1]=1; fac[0]=facinv[0]=1; rep(i,1,n){ if (i!=1) inv[i]=(P-P/i)*inv[P%i]%P; fac[i]=fac[i-1]*i%P; facinv[i]=facinv[i-1]*inv[i]%P; } f[0]=1; cdq(0,n); ll ans=0; rep(i,0,n) ans=(ans+f[i]*fac[i]%P)%P; if (ans<0) ans+=P; printf("%lld\n",ans); return 0; }

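解法二:利用第二類斯特林數的通項 $S(i,j)=\sum_{k=0}^{j}\frac{(-1)^k}{k!}\cdot\frac{(j-k)^i}{(j-k)!}$,交換求和順序後答案為 $\sum_{j=0}^{n}2^j\,j!\sum_{k=0}^{j}a_k\,b_{j-k}$,其中 $a_k=\frac{(-1)^k}{k!}$,$b_m=\frac{1}{m!}\sum_{i=0}^{n}m^i=\frac{m^{n+1}-1}{(m-1)\,m!}$(特別地 $b_0=1$,$b_1=n+1$)。因此只需做一次FFT(NTT)卷積即可。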

代碼用時:1.5h,在 long long 上出了一點問題。

整體上說還是比較簡單的。

#include<cstdio>
#include<algorithm>
// Inclusive counting loop: i takes the values l, l+1, ..., r.
#define rep(i,l,r) for (int i=l; i<=r; ++i)
typedef long long ll;
using namespace std;

// Capacity of every global array: 2^18 entries plus 5 spare slots.
const int N=(1<<18)+5;
// Modulus applied by every "%P" reduction in this program.
const int P=998244353;
// Base that DFT raises to a power (via ksm) to obtain its per-stage multiplier.
const int g=3;

// Upper index read from the input.
int n;
// Index permutation applied by DFT before its butterfly passes.
int rev[N];

// Accumulator for the value printed by main.
ll ans;
// inv[i]: modular inverse of i, filled by main.
ll inv[N];
// fac[i]: i! reduced mod P.
ll fac[N];
// facinv[i]: product of inv[1..i], i.e. the inverse of fac[i].
ll facinv[N];
// Declared alongside the other tables; not referenced by this solution.
ll f[N];
// The two polynomials that main multiplies with DFT.
ll a[N];
ll b[N];

// Modular power by repeated squaring.
// Returns a^b, reducing every intermediate product with %P.
// The loop stops once b reaches zero, so b is expected to be non-negative.
ll ksm(ll a,ll b){
   ll result=1;
   while (b){
      // Fold the current square into the result when the low bit is set.
      if (b & 1) result=result*a%P;
      a=a*a%P;
      b>>=1;
   }
   return result;
}

// In-place number-theoretic transform of a[0..n-1] modulo P.
//   a : coefficient buffer, overwritten with the transformed values
//   n : transform length; the caller must have filled rev[0..n-1] for this length
//   f : 1 selects the forward transform, -1 the inverse (which also scales by 1/n)
// All arithmetic is done in 64-bit, and results are reduced with %P only, so
// negative inputs produce results that may be negative residues.
void DFT(ll a[],int n,int f){
   // Reorder the input according to the precomputed permutation.
   for (int i=0; i<n; i++){
      const int j=rev[i];
      if (i<j) swap(a[i],a[j]);
   }

   for (int half=1; half<n; half<<=1){
      const int span=half<<1;
      // Forward uses g^((P-1)/span); inverse uses the complementary exponent,
      // which yields the modular inverse of that multiplier.
      // NOTE(review): presumably g is a primitive root of P - confirm.
      const int step=(P-1)/span;
      const ll wn=ksm(g,(f==1) ? step : (P-1)-step);

      for (int base=0; base<n; base+=span){
         ll w=1;
         for (int k=0; k<half; k++){
            ll &lo=a[base+k];
            ll &hi=a[half+base+k];
            const ll x=lo;
            const ll y=w*hi%P;
            lo=(x+y)%P;
            hi=(x-y+P)%P;
            w=w*wn%P;
         }
      }
   }

   if (f==-1){
      // The inverse transform divides by n: multiply by n^(P-2).
      const ll lenInv=ksm(n,P-2);
      for (int i=0; i<n; i++) a[i]=a[i]*lenInv%P;
   }
}

int main(){
   scanf("%d",&n); inv[1]=1; fac[0]=facinv[0]=1;
   rep(i,1,n){
      if (i!=1) inv[i]=(P-P/i)*inv[P%i]%P;
      fac[i]=fac[i-1]*i%P;
      facinv[i]=facinv[i-1]*inv[i]%P;
   }
   a[0]=1; b[0]=1; b[1]=n+1;
   rep(i,1,n) a[i]=((i&1)?-1:1)*facinv[i];
   rep(i,2,n) b[i]=(ksm(i,n+1)-1)*inv[i-1]%P*facinv[i]%P;
   ll lim=n+n+1,nn=1,L=0;
   while (nn<lim) nn<<=1,L++;
   rep(i,0,nn-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
   DFT(a,nn,1); DFT(b,nn,1);
   rep(i,0,nn-1) a[i]=a[i]*b[i];
   DFT(a,nn,-1);
   rep(i,0,n) ans=(ans+ksm(2,i)*fac[i]%P*a[i]%P)%P;
   if (ans<0) ans+=P;
   printf("%lld\n",ans);
   return 0;
}

[BZOJ4555][TJOI2016&HEOI2016]求和(分治FFT)