【比賽】NOIP2018 保衛王國
阿新 • • 發佈:2019-01-01
動態 DP（DDP）模板題：最小權點覆蓋 = 總權值 − 最大權獨立集，以樹鏈剖分 + 線段樹維護 (max,+) 矩陣。
// NOIP2018 "defense": dynamic DP on a tree (heavy-light decomposition plus a
// segment tree of (max, +) 2x2 matrices).
// Only the standard headers actually used are included; <bits/stdc++.h> is a
// libstdc++-internal header that other toolchains do not provide.
#include <cstdio>
#include <utility>

// Short type names (aliases instead of macros, so e.g. ll(x) is also valid).
using ui = unsigned int;
using ll = long long;
using db = double;
using ld = long double;
using ull = unsigned long long;

// Contest shorthands kept for compatibility. `register` was dropped from the
// loop macros because that storage class specifier is ill-formed since C++17.
#define ft first
#define sd second
#define pb(a) push_back(a)
#define mp(a,b) std::make_pair(a,b)
#define ITR(a,b) for(auto a:b)
#define REP(a,b,c) for(int a=(b),a##end=(c);a<=a##end;++a)
#define DEP(a,b,c) for(int a=(b),a##end=(c);a>=a##end;--a)

constexpr int MAXN = 100000 + 10;
// inf : stands in for "minus infinity" inside the (max, +) matrices below.
// vinf: offset the query code adds to / subtracts from a node's weight in order
//       to force that node into / out of the independent set.
constexpr ll inf = 1000000000000000000LL, vinf = 1000000000000LL;

int n, m, e, cnt;                              // node count, query count, edge counter, dfs counter
int beg[MAXN], nex[MAXN << 1], to[MAXN << 1];  // adjacency lists (each tree edge stored in both directions)
int size[MAXN], hson[MAXN], fa[MAXN];          // subtree size, heavy child, parent
int st[MAXN], ed[MAXN], top[MAXN];             // dfs position, position of the chain's bottom node, chain head
int w[MAXN];                                   // node weights
ll f[MAXN][2], all;                            // tree DP table, sum of all weights
char type[5];                                  // token read by main (never used afterwards)

// Assign y to x when it is an improvement; return whether x changed.
template<typename T> inline bool chkmin(T &x, T y) { return y < x ? (x = y, true) : false; }
template<typename T> inline bool chkmax(T &x, T y) { return y > x ? (x = y, true) : false; }

// 2x2 matrix over the (max, +) semiring. Every entry starts at -inf, which acts
// as "minus infinity" (the identity of max).
struct Matrix {
    ll a[2][2];

    Matrix()
    {
        for (auto &row : a)
            for (ll &cell : row) cell = -inf;
    }

    // (max, +) product: result[i][j] = max over k of (a[i][k] + A.a[k][j]).
    // The result is never smaller than -inf because it starts from -inf.
    Matrix operator*(const Matrix &A) const
    {
        Matrix B;
        for (int i = 0; i < 2; ++i)
            for (int k = 0; k < 2; ++k)
                for (int j = 0; j < 2; ++j)
                    chkmax(B.a[i][j], a[i][k] + A.a[k][j]);
        return B;
    }
};

// val[p]: transition matrix of the node whose dfs position is p; must be filled
// before Segment_Tree::Build is called.
Matrix val[MAXN];

// Segment tree over dfs positions. Each node stores the ordered product
// (left range first) of the matrices in its range.
struct Segment_Tree {
    Matrix sum[MAXN << 2];

    void PushUp(int rt) { sum[rt] = sum[rt << 1] * sum[rt << 1 | 1]; }

    // Builds node rt covering [l, r] from val[l..r].
    void Build(int rt, int l, int r)
    {
        if (l == r) {
            sum[rt] = val[l];
            return;
        }
        const int mid = (l + r) >> 1;
        Build(rt << 1, l, mid);
        Build(rt << 1 | 1, mid + 1, r);
        PushUp(rt);
    }

    // Replaces the matrix at position ps with k and refreshes its ancestors.
    void Update(int rt, int l, int r, int ps, const Matrix &k)
    {
        if (l == r) {
            sum[rt] = k;
            return;
        }
        const int mid = (l + r) >> 1;
        if (ps <= mid) Update(rt << 1, l, mid, ps, k);
        else Update(rt << 1 | 1, mid + 1, r, ps, k);
        PushUp(rt);
    }

    // Ordered product of the matrices at positions [L, R]; requires L <= R.
    Matrix Query(int rt, int l, int r, int L, int R) const
    {
        if (L <= l && r <= R) return sum[rt];
        const int mid = (l + r) >> 1;
        if (R <= mid) return Query(rt << 1, l, mid, L, R);
        if (L > mid) return Query(rt << 1 | 1, mid + 1, r, L, R);
        return Query(rt << 1, l, mid, L, R) * Query(rt << 1 | 1, mid + 1, r, L, R);
    }
};
Segment_Tree T;

// Reads one (optionally negative) decimal integer from stdin into x, skipping
// any other leading characters. If input ends before a number starts, x is set
// to 0 instead of looping forever. (ch is an int so EOF is distinguishable.)
template<typename T> inline void read(T &x)
{
    T data = 0, sign = 1;
    int ch = getchar();
    while (ch != '-' && (ch < '0' || ch > '9')) {
        if (ch == EOF) {
            x = 0;
            return;
        }
        ch = getchar();
    }
    if (ch == '-') sign = -1, ch = getchar();
    while (ch >= '0' && ch <= '9') {
        data = data * 10 + (ch - '0');
        ch = getchar();
    }
    x = data * sign;
}
template<typename T> inline void write(T x,char ch='\0') { if(x<0)putchar('-'),x=-x; if(x>9)write(x/10); putchar(x%10+'0'); if(ch!='\0')putchar(ch); } template<typename T> inline T min(T x,T y){return x<y?x:y;} template<typename T> inline T max(T x,T y){return x>y?x:y;} inline void insert(int x,int y) { to[++e]=y; nex[e]=beg[x]; beg[x]=e; } inline void dfs1(int x,int p) { int res=0; size[x]=1;fa[x]=p; for(register int i=beg[x];i;i=nex[i]) if(to[i]==p)continue; else { dfs1(to[i],x); size[x]+=size[to[i]]; if(chkmax(res,size[to[i]]))hson[x]=to[i]; } } inline void dfs2(int x,int tp) { top[x]=tp;st[x]=++cnt; val[cnt].a[0][0]=val[cnt].a[0][1]=f[x][0]; val[cnt].a[1][0]=f[x][1]; if(hson[x]) { val[cnt].a[0][0]-=max(f[hson[x]][0],f[hson[x]][1]); val[cnt].a[0][1]=val[cnt].a[0][0]; val[cnt].a[1][0]-=f[hson[x]][0]; dfs2(hson[x],tp);ed[x]=ed[hson[x]]; } else ed[x]=cnt; for(register int i=beg[x];i;i=nex[i]) if(to[i]==fa[x]||to[i]==hson[x])continue; else dfs2(to[i],to[i]); } inline void dfs(int x) { f[x][1]=w[x]; for(register int i=beg[x];i;i=nex[i]) if(to[i]==fa[x])continue; else { dfs(to[i]); f[x][1]+=f[to[i]][0]; f[x][0]+=max(f[to[i]][0],f[to[i]][1]); } } inline void init() { dfs1(1,0);dfs(1);dfs2(1,1); T.Build(1,1,n); } inline void solve(int u,ll v) { Matrix A,B,C; B=T.Query(1,1,n,st[u],st[u]); A=T.Query(1,1,n,st[top[u]],ed[u]); B.a[1][0]+=v; T.Update(1,1,n,st[u],B); while(u) { B=T.Query(1,1,n,st[top[u]],ed[u]); u=fa[top[u]]; if(!u)break; C=T.Query(1,1,n,st[u],st[u]); C.a[0][0]+=max(B.a[0][0],B.a[1][0])-max(A.a[0][0],A.a[1][0]); C.a[0][1]=C.a[0][0]; C.a[1][0]+=B.a[0][0]-A.a[0][0]; A=T.Query(1,1,n,st[top[u]],ed[u]); T.Update(1,1,n,st[u],C); } } inline ll value(int ot1,int ot2) { Matrix A=T.Query(1,1,n,st[1],ed[1]); return max(A.a[0][0],A.a[1][0])+(ot1?0:-vinf)+(ot2?0:-vinf); } int main() { freopen("defense.in","r",stdin); freopen("defense.out","w",stdout); read(n);read(m);scanf("%s",type); REP(i,1,n)read(w[i]),all+=w[i]; REP(i,1,n-1) { int u,v;read(u);read(v); 
insert(u,v);insert(v,u); } init(); while(m--) { int a,x,b,y;read(a);read(x);read(b);read(y); if((fa[a]==b||fa[b]==a)&&!x&&!y) { puts("-1"); continue; } solve(a,x?-vinf:vinf); solve(b,y?-vinf:vinf); printf("%lld\n",all-value(x,y)); solve(a,x?vinf:-vinf); solve(b,y?vinf:-vinf); } return 0; }