1. 程式人生 > >「日常訓練&知識學習」莫隊演算法(二):樹上莫隊(Count on a tree II,SPOJ COT2)

「日常訓練&知識學習」莫隊演算法(二):樹上莫隊(Count on a tree II,SPOJ COT2)

題意與分析

題意是這樣的,給定一顆節點有權值的樹,然後給若干個詢問,每次詢問讓你找出一條鏈上有多少個不同權值。
寫這題之前要參看我的三個blog:CFR326D2E、CFR340D2E和HYSBZ-1086,然後再看這幾個Blog——
參考A:https://blog.sengxian.com/algorithms/mo-s-algorithm
參考B:https://www.cnblogs.com/oyking/p/4265823.html
參考C:https://blog.csdn.net/u014609452/article/details/50675370
參考D:https://www.cnblogs.com/RabbitHu/p/MoDuiTutorial.html


然後,至少抄的時候心裡稍微明白了一點了23333
接下來,我們具體的分析一下樹上莫隊是如何實現的,以及具體的心路歷程。

具體分析

莫隊演算法的進一步理解

莫隊演算法的核心只有5行:

while(pl < q[i].l) del(a[pl++]);
while(pl > q[i].l) add(a[--pl]);
while(pr < q[i].r) add(a[++pr]);
while(pr > q[i].r) del(a[pr--]);
ans[q[i].id] = sum;

那麼它為啥在排序之後就work了呢?

  1. 左端點所在塊編號確定時,右端點位置保證不下降
    ,所以右端點移動最多造成的時間複雜度是\(O(n)\)的,總共左端點的塊數是\(\sqrt n\)塊,從而總時間複雜度為\(O(n \sqrt n)\)
  2. 左端點所在塊編號變動時,右端點移動最多有\(O(n)\)的時間複雜度(最壞從\(n\)移動回\(1\)),總共\(\sqrt n\)塊,從而總時間複雜度也為\(O(n\sqrt n)\)
  3. 塊內左端點位置每次最多移動\(\sqrt n\),一共\(m\)次詢問,也就是一共移動\(m\)次,總時間複雜度為\(O(m\sqrt n)\)

總上,莫隊演算法具有\(O(n\sqrt n)\)的時間複雜度,相當(在\(n=10^5\)範圍)夠用了。

樹上的遷移

首先為了利用莫隊演算法,我們需要對樹進行分塊。利用BZOJ-1086的分塊方法可以確保比較好的分塊性質:每一塊“相對地”聚在一起,且大小在\([BLOCK\_SIZE, 3BLOCK\_SIZE]\)之間。然後接下來然後我們還是要對所有詢問進行排序。排序依據是左端點所在塊的編號、右端點所在塊的編號、時間。
最後,從樸素莫隊遷移過來的一個重要問題就是:如何移動起點終點?
在序列中,左右端點的移動方式是顯然的,一個端點的移動只有兩個方向——左和右;而它們帶來的影響也是顯然的——區間增加或刪除一個元素。然而樹上莫隊卻不是非常顯然……最佳的(dalao想出來的)方案是:維護一個vis布林陣列,記錄每個節點是否在當前處理的路徑上(LCA非常難辦,我們在維護路徑上的點時不包括LCA,求答案的時候臨時把LCA加上)。每次從上一個詢問\((u_s,v_s)\)轉移到當前詢問\((u_t,v_t)\)時,我們要做的是把路徑\((u_s,u_t)\)\((v_s,v_t)\)上的點的vis逐個取反,同時對應地維護答案。
這樣為啥是對的呢?VFleaKing(一位julao)的部落格中(現在似乎沒法打開了)有證明,證明部分摘錄如下(\(\oplus\)表示類似異或的操作,即節點出現兩次會消去):
(摘者注:\(T(v,u)\)可以理解為一次\((u,v)\)樹上鍊的操作。)
\[T(v, u) = S(root, v) \oplus S(root, u)\](之前的摘者注:顯然等式右側是u到v的路徑上除lca以外的點)
觀察將\(cur_V\)移動到\(target_V\)前後\(T(cur_V, cur_U)\)變化:
\[T(cur_V, cur_U) = S(root, cur_V) \oplus S(root, cur_U) \\ T(target_V, cur_U) = S(root, target_V) \oplus S(root, cur_U)\]
取對稱差:(摘者注:目的是觀察移動區間時出現了什麼變化)
\[T(cur_V, cur_U) \oplus T(target_V, cur_U) = (S(root, cur_V) \oplus S(root, cur_U)) \oplus (S(root, target_V) \oplus S(root, cur_U))\]
由於對稱差的交換律、結合律:
\[T(cur_V, cur_U) \oplus T(target_V, cur_U)= S(root, cur_V) \oplus S(root, target_V)\]
兩邊同時\(\oplus T(cur_V, cur_U)\):(摘者注:清理左邊)
\[T(target_V, cur_U)= T(cur_V, cur_U) \oplus S(root, cur_V) \oplus S(root, target_V)\]
發現最後兩項很爽……哇哈哈(摘者注:利用這種操作的性質)
\[T(target_V, cur_U)= T(cur_V, cur_U) \oplus T(cur_V, target_V)\]
(摘者注:最後可以發現,左端點A->B只需要對現有區間做一個(A,B)的反向操作即可,而這正是莫隊演算法的要求)
這就是樹上莫隊了。是不是很簡單(大霧)

這一題的應用

我們注意到,統計的是鏈上不同點的個數,這個恰恰是可以有這種性質的:\(ans[l,r]\)可以(在\(O(1)\)時間)得到\(ans[l+1,r],ans[l-1,r],ans[l,r+1],ans[l,r-1]\)。於是,這裡就可以應用樹上莫隊演算法了。
具體的實現類似莫隊,只是對區間的處理有點小不同:考慮\((x,y)\)間路徑的資訊,如果\(x\)\(y\)的祖先,那麼所求資訊就為\(x\)\(y\)最後出現的位置之間的資訊(這裡程式碼的實現很簡單:如果二者深度不一樣,就一格一格地跳上來,反正也不需要快速維護,莫隊已經保證了;如果\(x\)\(y\)的祖先——不失一般性,設\(x<y\),那麼到這裡已經結束了);如果\(x\)不是\(y\)的祖先,那麼二者逐步逐步地邊跳到LCA邊維護資訊即可,最後全部的操作結束後再把LCA加上即可(原因上面已經說了)。
預處理有兩個:1)dfs分塊、判斷深度,2)做一個倍增LCA(Tarjan似乎也可以?)。

程式碼

/*
 * Filename: spoj_cot2.cpp
 * Date: 2018-11-13
 */

#include <bits/stdc++.h>

#define INF 0x3f3f3f3f
#define PB emplace_back
#define MP make_pair
#define fi first
#define se second
#define rep(i,a,b) for(repType i=(a); i<=(b); ++i)
#define per(i,a,b) for(repType i=(a); i>=(b); --i)
#define ZERO(x) memset(x, 0, sizeof(x))
#define MS(x,y) memset(x, y, sizeof(x))
#define ALL(x) (x).begin(), (x).end()

#define QUICKIO                  \
    ios::sync_with_stdio(false); \
    cin.tie(0);                  \
    cout.tie(0);
#define DEBUG(...) fprintf(stderr, __VA_ARGS__), fflush(stderr)

using namespace std;
using pi=pair<int,int>;
using repType=int;
using ll=long long;
using ld=long double;
using ull=unsigned long long;

const int MAXN=40005, MAXM=100005;
const int MAXE=MAXN<<1, MLOG=20;

int val[MAXN];
vector<int> G[MAXN];
int n,m;

int stk[MAXN], top=0;
int blk[MAXN], bcnt, bsz;

struct Query
{
    int u, v, id;
    void read(int i)
    {
        id=i;
        scanf("%d%d", &u, &v);
    }
    void adjust()
    {
        if(blk[u]>blk[v]) swap(u,v);
    }
    bool operator < (const Query& rhs) const
    {
        if(blk[u]!=blk[rhs.u]) return blk[u]<blk[rhs.u];
        else return blk[v]<blk[rhs.v]; // Right Range First
    }
} asks[MAXM];
int ans[MAXM];

// Graph
inline void init()
{
    rep(i,1,n) G[i].clear();
}

// Discretization
void get_hash(int a[], int n)
{
    static int tmp[MAXM];
    int cnt=0;
    rep(i,1,n) tmp[cnt++]=a[i];
    sort(tmp,tmp+cnt);
    cnt=unique(tmp,tmp+cnt)-tmp;
    rep(i,1,n) a[i]=lower_bound(tmp,tmp+cnt,a[i])-tmp+1;
}

// Input Read
inline void read_input()
{
    scanf("%d%d", &n, &m);
    rep(i,1,n) scanf("%d", &val[i]);
    get_hash(val, n);
    init();
    rep(i,1,n-1)
    {
        int u,v;
        scanf("%d%d", &u, &v);
        G[u].PB(v);
        G[v].PB(u);
    }
    rep(i,0,m-1) asks[i].read(i);
}

// Find Blks: See BZOJ 1086
inline void add_blk(int& cnt)
{
    while(cnt--) blk[stk[--top]]=bcnt;
    bcnt++;
    cnt=0;
}
inline void rst_blk()
{
    while(top) blk[stk[--top]]=bcnt-1;
}
int dfs_blk(int now, int pre)
{
    int sz=0;
    rep(i,0,int(G[now].size())-1) if(G[now][i]!=pre)
    {
        sz+=dfs_blk(G[now][i], now);
        if(sz>=bsz) add_blk(sz);
    }
    stk[top++]=now;
    sz++;
    if(sz>=bsz) add_blk(sz);
    return sz;
}
inline void init_blk()
{
    bsz=max(1,(int)sqrt(n));
    dfs_blk(1,0);
    rst_blk();
}

// Ask for RMQs: LCA
int fa[MLOG][MAXM], dep[MAXM];

inline void dfs_lca(int u, int f, int ndep)
{
    dep[u]=ndep;
    fa[0][u]=f;
    rep(i,0,int(G[u].size()-1)) if(G[u][i]!=f)
        dfs_lca(G[u][i],u,ndep+1);
}
inline void init_lca()
{
    dfs_lca(1,-1,0);
    rep(k,0,MLOG-2)
    {
        rep(u,1,n)
        {
            if(fa[k][u]==-1) fa[k+1][u]=-1;
            else fa[k+1][u]=fa[k][fa[k][u]];
        }
    }
}
int ask_lca(int u,int v)
{
    if(dep[u]<dep[v]) swap(u,v);
    rep(k,0,MLOG-1)
        if((dep[u]-dep[v]) & (1<<k)) u=fa[k][u];
    if(u==v) return u;
    per(k,MLOG-1,0)
        if(fa[k][u]!=fa[k][v])
        {
            u=fa[k][u];
            v=fa[k][v];
        }
    return fa[0][u];
}

// Mo's algorithm
bool vis[MAXM];
int diff, cnt[MAXM];

inline void xor_node(int u)
{
    if(vis[u])
    {
        vis[u]=false;
        diff-=(--cnt[val[u]]==0);
    }
    else
    {
        vis[u]=true;
        diff+=(++cnt[val[u]]==1);
    }
}

inline void xor_path_without_lca(int u, int v)
{
    if(dep[u]<dep[v]) swap(u,v);
    while(dep[u]!=dep[v])
    {
        xor_node(u);
        u=fa[0][u];
    }
    while(u!=v)
    {
        xor_node(u);
        u=fa[0][u];
        xor_node(v);
        v=fa[0][v];
    }
}

inline void mv_node(int u, int v, int taru, int tarv)
{
    xor_path_without_lca(u, taru);
    xor_path_without_lca(v, tarv);

    xor_node(ask_lca(u,v));
    xor_node(ask_lca(taru,tarv));
}

inline void make_ans()
{
    rep(i,0,m-1)
        asks[i].adjust(); // make every query has a u,v that u<v
    sort(asks,asks+m);
    int nowu=1,nowv=1; xor_node(1);
    rep(i,0,m-1) // Mo's algorithm -- basis
    {
        mv_node(nowu, nowv, asks[i].u, asks[i].v);
        ans[asks[i].id]=diff;
        nowu=asks[i].u;
        nowv=asks[i].v;
    }
}

inline void print_ans()
{
    rep(i,0,m-1) printf("%d\n", ans[i]);
}

int main()
{
    read_input();
    init_blk();
    init_lca();
    make_ans();
    print_ans();

    return 0;
}