1. 程式人生 > >洛谷P3899 [湖南集訓]談笑風生(線段樹合併)

洛谷P3899 [湖南集訓]談笑風生(線段樹合併)

題目連結

首先分析一下題意,會發現有這麼一句話

2.a 和 b 都比 c 不知道高明到哪裡去了;

思考一下:如果aacc高,說明aacc的祖先,同樣bbcc的祖先 那麼顯然**aabb肯定都在cc到根的路徑上**,所以要麼aabb的祖先,要麼bbaa的祖先

先來思考bbaa的祖先的情況:

首先cc肯定是在aa的子樹裡,有size[a]1size[a]-1個 然後bb的個數有min(deep[a]1,k)min(deep[a]-1,k)個 所以這一部分的貢獻是(size[a]1)min

(deep[a]1,k)(size[a]-1)*min(deep[a]-1,k)

接著是aabb的祖先的情況

bbaa子樹中距離aa點深度kk以內的點,每個cc點從bb點的子樹中取。 這一部分的貢獻是共有size[b]1(deep[b]deep[a]<=k)\sum size[b]-1(deep[b]-deep[a]<=k)

顯然第一部分dfs的時候O(1)O(1)隨便搞搞就可以了,重點在第二部分

我們在每個點弄一顆線段樹,以深度為下標,記錄這個點子樹中每個深度對應的c

c點的個數

到時候只需要查詢deep[a]+1deep[a]+kdeep[a]+1 \rightarrow deep[a]+k之間的權值和就行了

為了不MLE,可以考慮在dfs的時候將子節點線段樹合併到父節點線段樹上,然後就寫完了

注意要開下longlong…


#include<bits/stdc++.h>
#define lson tr[now].l
#define rson tr[now].r
using namespace std;

struct tree
{
    long long sum;
    int l,r;
}tr[
20000010]; struct op { int k,id; }; int n,m; int rt[300010],cnt,deep[300010]; long long ans[300010],size[300010]; vector<int> g[300010]; vector<op> gg[300010]; int dfs(int now,int fa,int dep) { deep[now]=dep; size[now]=1; rt[now]=++cnt; for(int i=0;i<g[now].size();i++) { if(g[now][i]==fa) continue; dfs(g[now][i],now,dep+1); size[now]+=size[g[now][i]]; } } int push_up(int now) { tr[now].sum=tr[lson].sum+tr[rson].sum; } int insert(int &now,int l,int r,int pos,int val) { if(!now) now=++cnt; if(l==r) { tr[now].sum+=val; return 0; } int mid=(l+r)>>1; if(pos<=mid) { insert(lson,l,mid,pos,val); } else { insert(rson,mid+1,r,pos,val); } push_up(now); } long long query(int now,int l,int r,int ll,int rr) { if(ll>rr) return 0; if(ll<=l&&r<=rr) return tr[now].sum; int mid=(l+r)>>1; if(rr<=mid) { return query(lson,l,mid,ll,rr); } else { if(mid<ll) { return query(rson,mid+1,r,ll,rr); } else { return query(lson,l,mid,ll,mid)+query(rson,mid+1,r,mid+1,rr); } } } int merge(int a,int b,int l,int r) { if(!a) return b; if(!b) return a; if(l==r) { tr[a].sum+=tr[b].sum; return a; } int mid=(l+r)>>1; tr[a].l=merge(tr[a].l,tr[b].l,l,mid); tr[a].r=merge(tr[a].r,tr[b].r,mid+1,r); push_up(a); return a; } int dfs2(int now,int fa) { insert(rt[now],1,300000,deep[now],size[now]-1); for(int i=0;i<g[now].size();i++) { if(g[now][i]==fa) continue; dfs2(g[now][i],now); merge(rt[now],rt[g[now][i]],1,300000); } for(int i=0;i<gg[now].size();i++) { int id=gg[now][i].id; int k=gg[now][i].k; long long sum1=(size[now]-1)*min(deep[now]-1,k); long long sum2=query(rt[now],1,300000,deep[now]+1,min(deep[now]+k,300000)); ans[id]=sum1+sum2; } } int main() { scanf("%d%d",&n,&m); int from,to; for(int i=1;i<n;i++) { scanf("%d%d",&from,&to); g[from].push_back(to); g[to].push_back(from); } int pos,k; for(int i=1;i<=m;i++) { scanf("%d%d",&pos,&k); gg[pos].push_back({k,i}); } dfs(1,0,1); dfs2(1,0); for(int i=1;i<=m;i++) { printf("%lld\n",ans[i]); } }