1. 程式人生 > >愛之箭發射

愛之箭發射

題目描述

小海是弓道部的成員,非常擅長射箭(Love Arrow Shoot)。今天弓道部的練習是要射一棵樹。一棵樹是一個n個點n−1條邊的無向圖,且這棵樹的第ii個點有一個值wi,wi∈[1,m]。每一次小海會射中樹的一條邊,並將這條邊移除。此外,小海定義一棵樹的las值為∑vi∗i,vi為這棵樹中第ii小的wi。現在小海會告訴你她射中的邊的順序,你需要回答每一次她射中的邊所在的樹的las值,之後被射中的邊會被移除。答案mod998244353

題解

首先想到加邊一定比刪邊更好做。那麼我們要將兩顆子樹合併,考慮到我們有一個神奇的M,提醒我們使用權值線段樹。
思考線段樹合併。記錄下sum表示答案,num表示個數,s表示裡面數的和。那麼考慮小的數會把大的數推到後面,也就是如果有num個小的數,那麼原來 s

u m = a 1 + b 2 +
c 3.... sum=a*1+b*2+c*3.... 會變成 a ( 1
+ n u m ) + b ( 2 + n u m ) + c ( 3 + n u m ) . . . . = s u m + s n u m a*(1+num)+b*(2+num)+c*(3+num)....=sum+s*num

於是就可以合併了,邊界是葉子節點

程式碼

#include<bits/stdc++.h>
#define maxn 500005
#define MAXN 1000005
#define mod 998244353
#define INF 0x3f3f3f3f
#define LL long long
#define LO tr[o].lo
#define RO tr[o].ro
#define mid (l+r>>1)
using namespace std;
LL read(){
    LL res,f=1; char c;
    while(!isdigit(c=getchar())) if(c=='-') f=-1; res=(c^48);
    while(isdigit(c=getchar())) res=(res<<3)+(res<<1)+(c^48);
    return res*f;
}
struct TR{
    LL lo,ro,num,sum,s;
}tr[maxn<<4];
LL pre[maxn],tot,root[maxn];
LL ans[maxn];
int insert(LL &o,int l,int r,int w){
    if(!o) o=++tot;
    tr[o].lo=tr[o].ro=tr[o].num=tr[o].s=tr[o].sum=0;
    if(l==r){
        tr[o].num=1;
        tr[o].s=tr[o].sum=w;
        return o;
    }
    if(w<=mid) insert(LO,l,mid,w);
    else insert(RO,mid+1,r,w);
    tr[o].num=tr[LO].num+tr[RO].num;
    tr[o].sum=(tr[LO].sum+tr[RO].sum+tr[LO].num*tr[RO].s%mod)%mod;
    tr[o].s=tr[LO].s+tr[RO].s;
}
int merge(int x,int y){
    if(!x || !y) return x+y;
    tr[x].lo=merge(tr[x].lo,tr[y].lo);
    tr[x].ro=merge(tr[x].ro,tr[y].ro); 
    if(!tr[x].lo && !tr[x].ro){
        tr[x].sum=(tr[x].sum+tr[y].sum+tr[x].num*tr[y].s)%mod;
        tr[x].num=(tr[x].num+tr[y].num)%mod;
        tr[x].s=(tr[x].s+tr[y].s)%mod;
        return x;
    }
    tr[x].sum=(tr[tr[x].lo].sum+tr[tr[x].ro].sum+tr[tr[x].lo].num*tr[tr[x].ro].s)%mod;
    tr[x].num=(tr[tr[x].lo].num+tr[tr[x].ro].num)%mod;
    tr[x].s=(tr[tr[x].lo].s+tr[tr[x].ro].s)%mod;
    return x;
}
int find(int x){return x==pre[x]?x:pre[x]=find(pre[x]);}
int n,m,u[maxn],v[maxn],order[maxn];
int main(){
    n=read(); m=read();
    for(int i=1;i<=n;i++){
        insert(root[i],1,m,read());
        pre[i]=i;
    }
    for(int i=1;i<n;i++) u[i]=read(),v[i]=read();
    for(int i=1;i<n;i++) order[i]=read();
    for(int i=n-1;i;i--){
        int x=find(u[order[i]]),y=find(v[order[i]]);
        root[x]=merge(root[x],root[y]);
        ans[i]=tr[root[x]].sum;
        pre[y]=x;
    }
    for(int i=1;i<n;i++){
        printf("%lld\n",ans[i]);
    }
    return 0;
}