1. 程式人生 > >2018.11.16 bzoj4827: [Hnoi2017]禮物(ntt)

2018.11.16 bzoj4827: [Hnoi2017]禮物(ntt)

傳送門
n t t ntt 入門題。
考慮展開要求的式子 i =

0 n 1 ( x i
y i c ) 2 \sum_{i=0}^{n-1}(x_i-y_i-c)^2
=> i = 0 n 1 ( x i 2 + y i 2 + c 2 2 c ( x i y i ) 2 x i y i ) \sum_{i=0}^{n-1}(x_i^2+y_i^2+c^2-2c(x_i-y_i)-2x_iy_i)
s u m = i = 0 n 1 x i y i sum=\sum_{i=0}^{n-1}x_i-y_i
=> ( i = 0 n 1 ( x i 2 + y i 2 ) ) + min { n c 2 2 s u m c } 2 max { i = 0 n 1 x i y i } (\sum_{i=0}^{n-1}(x_i^2+y_i^2))+\min\{nc^2-2sum*c\}-2*\max\{ \sum_{i=0}^{n-1}x_iy_i\}
化到這一步的時候我卡了一下。
考慮第一坨可以直接算,第二坨可以取二次函式對稱軸附近求極值。
關鍵在於第三坨的極值。
這個怎麼求呢?
考慮將 b b 陣列翻轉。
這樣對 a , b a,b 進行 n t t ntt ,然後得到的新數列 { z n } \{z_n\} 中的 z i + z n + i z_i+z_{n+i} 就對應著數列 a n , b n a_n,b_n 的一種對齊方法的和。
因此只需要在 n t t ntt 之後給 z i + z n + i z_i+z_{n+i} 取個 max \max 就行了。
程式碼:

#include<bits/stdc++.h>
using namespace std;
inline int read(){
	int ans=0;
	char ch=getchar();
	while(!isdigit(ch))ch=getchar();
	while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
	return ans;
}
typedef long long ll;
const int mod=998244353,N=2e5+5;
int n,m,a[N],b[N],pos[N],ans=0,lim=1,tim=0,mx=0,sum=0;
inline int ksm(int a,int p){int ret=1;for(;p;p>>=1,a=(ll)a*a%mod)if(p&1)ret=(ll)ret*a%mod;return ret;}
inline void ntt(int a[],int type){
	for(int i=0;i<lim;++i)if(i<pos[i])swap(a[i],a[pos[i]]);
	int mult=(mod-1)>>1,typ=type==1?3:(mod+1)/3;
	for(int wn,mid=1;mid<lim;mid<<=1,mult>>=1){
		wn=ksm(typ,mult);
		for(int len=mid<<1,j=0;j<lim;j+=len){
			int w=1;
			for(int k=0;k<mid;++k,w=(ll)w*wn%mod){
				int a0=a[j+k],a1=(ll)a[j+k+mid]*w%mod;
				a[j+k]=(a0+a1)%mod,a[j+k+mid]=(a0-a1+mod)%mod;
			}
		}
	}
}
inline int calc(int x){return (n+1)*x*x-2*sum*x;}
int main(){
	freopen("lx.in","r",stdin);
	n=read()-1,m=read();
	for(int i=0;i<=n;++i)a[i]=read(),ans+=a[i]*a[i],sum+=a[i];
	for(int i=n;~i;--i)b[i]=read(),ans+=b[i]*b[i],sum-=b[i];
	while(lim<=n*2)lim<<=1,++tim;
	int p=sum/(n+1);
	ans+=min(min(calc(p-1),calc(p)),calc(p+1));
	for(int i=0;i<lim;++i)pos[i]=(pos[i>>1]>>1)|((i&1)<<(tim-1));
	ntt(a,1),ntt(b,1);
	for(int i=0;i<lim;++i)a[i]=(ll)a[i]*b[i]%mod;
	ntt(a,-1);
	int inv=ksm(lim,mod-2);
	for(int i=0;i<=n*2;++i)a[i]=(ll)a[i]*inv%mod;
	for(int i=0;i<=n;++i)mx=max(mx,a[i]+a[i+n+1]);
	cout<<ans-2*mx;
	return 0;
}