1. 程式人生 > >【NOIP 校內模擬】T3 忘了是啥名字了(dfs序+樹狀陣列)

【NOIP 校內模擬】T3 忘了是啥名字了(dfs序+樹狀陣列)

對於當前新加入的一條路徑 他產生的貢獻分為兩種

1.另一條路徑的LCA在當前路徑上
2.當前路徑的LCA在另一條上

對於情況1:

可以維護當前點到根節點有多少個LCA,查詢只需查詢u,v,-2*lca(u,v),修改需要對lca的子樹+1

對於情況2:

顯然的樹上差分,查詢就是lca子樹的字首和,修改u++,v++,lca-2

即開兩個樹狀陣列,一個支援單點查詢+區間修改,一個支援單點修改+區間查詢,不嫌麻煩的話可以嘗試線段樹。

需要開棧,某OJ棧空間感人。

#include<bits/stdc++.h>
#define N 1000005
#define M 1000005
#define ll long long
using namespace std;
template<class T>
inline void read(T &x)
{
    x=0; int f=1;
    static char ch=getchar();
    while((!isdigit(ch))&&ch!='-')  ch=getchar();
    if(ch=='-') f=-1,ch=getchar();
    while(isdigit(ch))  x=x*10+ch-'0',ch=getchar();
    x*=f;
}
//1e6
struct Edge
{
    int to,next;
}edge[2*N];
int n,m,tot,first[N];
inline void addedge(int x,int y)
{
    tot++;
    edge[tot].to=y; edge[tot].next=first[x]; first[x]=tot;
}
int up[N][27],depth[N],st[N],sign,ed[N];
ll con[N];
void dfs(int now,int fa)
{
    up[now][0]=fa;
    depth[now]=depth[fa]+1;
    st[now]=++sign;
    for(int i=1;i<=25;i++)  up[now][i]=up[up[now][i-1]][i-1];
    for(int u=first[now];u;u=edge[u].next)
    {
        int vis=edge[u].to;
        if(vis==fa) continue;
        dfs(vis,now);
    }
    ed[now]=sign;
}
inline int getlca(int x,int y)
{
    if(depth[x]<depth[y])   swap(x,y);
    for(int i=25;i>=0;i--) if(depth[up[x][i]]>=depth[y]) x=up[x][i];
    if(x==y)    return x;
    for(int i=25;i>=0;i--) if(up[x][i]!=up[y][i])   x=up[x][i],y=up[y][i];
    return up[x][0];
}
inline int lowbit(int x)
{
    return x&(-x);
}
struct BIT
{
    int n;
    ll tree[N];
    inline void getn(int x)
    {
        n=x;
    }
    inline void update(int x,ll del)
    {
        for(int i=x;i<=n;i+=lowbit(i))  tree[i]+=del;
    }
    inline ll query(int x)
    {
        ll ans=0;
        for(int i=x;i;i-=lowbit(i)) ans+=tree[i];
        return ans;
    }
}bit1,bit2; //區間加單點查  單點加區間查    其實就是差分,普通 bit
int main()
{
    ll size=40<<20;//40M
    __asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));//提交用這個 
    read(n),read(m);
    for(register int i=1;i<n;i++)
    {
        int x,y;
        read(x),read(y);
        addedge(x,y); addedge(y,x);
    }
    dfs(1,0);
    bit1.getn(n); bit2.getn(n);
    ll ans=0;
    //需要分兩種情況討論:其他的lca在這條路徑上  自己的lca在其他路徑上 
    for(int i=1,u,v,lca;i<=m;i++)
    {
        read(u); read(v); lca=getlca(u,v);
        ans=ans+bit1.query(st[u])+bit1.query(st[v])-2*bit1.query(st[lca]);
        ans=ans+bit2.query(ed[lca])-bit2.query(st[lca]-1);
        ans=ans+con[lca];
        con[lca]++;
        bit1.update(st[lca],1); bit1.update(ed[lca]+1,-1);
        bit2.update(st[u],1); bit2.update(st[v],1); bit2.update(st[lca],-2);
    }
    cout<<ans;
    exit(0);
}