1. 程式人生 > >Luogu P3391 【模板】文藝平衡樹(FHQ-Treap)

Luogu P3391 【模板】文藝平衡樹(FHQ-Treap)

題意

給出一個長為$n$序列$[1,2,...,n]$,$m$次操作,每次指定一段區間$[l,r]$,將這段區間翻轉,求最終序列

題解

雖然標題是$Splay$,但是我要用$FHQ\ Treap$,考慮先將$[l,r]$這段區間$split$出來($k$即為這段區間)

void split(int o, int k, int &l, int &r) {
    if(!o) { l = r = 0; return ; }
    if(siz[lc[o]] < k) l = o, split(rc[o], k - siz[lc[o]] - 1, rc[o], r);
    else r = o, split(lc[o], k, l, lc[o]);
    upt(o);
}//注意這裡要按size來split

//寫在main函式中
while(m--) {
    read(x), read(y);
    split(rt, y, l, r), split(l, x - 1, l, k);
    rev[k] ^= 1; rt = merge(merge(l, k), r);
}//x,y為操作的區間

然後再將這段區間打一個翻轉標記(因為平衡樹是可以中序遍歷輸出的吧...,$rev$為翻轉標記)

每次涉及到某個節點時,將$rev$標記下放就好了

void pushdown(int o) {
    std::swap(lc[o], rc[o]);
    if(lc[o]) rev[lc[o]] ^= 1;
    if(rc[o]) rev[rc[o]] ^= 1;
    rev[o] = 0;
}

#include <ctime>
#include <cstdio>
#include <cstdlib>
#include <algorithm>

template<typename T>
void read(T &x) {
    int flag = 1; x = 0; char ch = getchar();
    while(ch < '0' || ch > '9') { if(ch == '-') flag = -flag; ch = getchar(); }
    while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar(); x *= flag;
}

const int N = 1e5 + 10;
int n, m, lc[N], rc[N], siz[N], val[N], pri[N], rev[N], tot;

inline void upt(int o) { siz[o] = siz[lc[o]] + siz[rc[o]] + 1; }
inline int node(int x) { val[++tot] = x, pri[tot] = rand(), siz[tot] = 1; return tot;}
void pushdown(int o) {
    std::swap(lc[o], rc[o]);
    if(lc[o]) rev[lc[o]] ^= 1;
    if(rc[o]) rev[rc[o]] ^= 1;
    rev[o] = 0;
}
void split(int o, int k, int &l, int &r) {
    if(!o) { l = r = 0; return ; }
    if(rev[o]) pushdown(o);
    if(siz[lc[o]] < k) l = o, split(rc[o], k - siz[lc[o]] - 1, rc[o], r);
    else r = o, split(lc[o], k, l, lc[o]);
    upt(o);
}
int merge(int l, int r) {
    if(!l || !r) return l + r;
    if(pri[l] < pri[r]) { if(rev[l]) pushdown(l); rc[l] = merge(rc[l], r), upt(l); return l; }
    else { if(rev[r]) pushdown(r); lc[r] = merge(l, lc[r]), upt(r); return r; }
}
void print(int o) {
    if(!o) return ;
    if(rev[o]) pushdown(o);
    print(lc[o]), printf("%d ", val[o]), print(rc[o]);
}

int main () {
    read(n), read(m), srand((unsigned)time(NULL));
    int x, y, l, r, k, rt = 0;
    for(int i = 1; i <= n; ++i) rt = merge(rt, node(i));
    while(m--) {
        read(x), read(y);
        split(rt, y, l, r), split(l, x - 1, l, k);
        rev[k] ^= 1; rt = merge(merge(l, k), r);
    } print(rt);
    return 0;
}