1. 程式人生 > >luogu P3391 【模板】文藝平衡樹(Splay)

luogu P3391 【模板】文藝平衡樹(Splay)

嘟嘟嘟


突然覺得splay挺有意思的……


這道題只有一個任務:區間翻轉。
首先應該知道的是,splay和線段樹一樣,都可以打標記,然後走到每一個節點之前先下傳。
那怎麼打標記呢?還應該有“區間”的思想。
對於區間\([L, R]\),想辦法把這個區間所在的子樹提取出來,然後打個標記即可。
那怎麼提取呢?其實也不難。只要找出\(L\)的前驅\(a = L - 1\)\(R\)的後繼\(b = R + 1\),然後把\(a\)旋到根,再把\(b\)旋到根的右子節點,這樣\(b\)的左子樹就是當前區間了。
但是找前驅和後繼只能像bst那麼找,因為這棵splay的key值是下標,而下標並沒有存起來,而是通過子樹大小體現的。所以上述找前驅和後繼操作相當於查詢第\(k\)

大。因為事先加了\(-INF\)\(INF\)防止越界,所以找前驅就是查詢第\(L\)大的,後繼就是第\(R + 2\)大的。

int getRank(int k)
{
  int now = root;
  while(1)
    {
      pushdown(now);
      if(t[t[now].ch[0]].siz >= k) now = t[now].ch[0];
      else if(t[t[now].ch[0]].siz + 1 == k) return now;
      else k -= t[t[now].ch[0]].siz + 1, now = t[now].ch[1];
    }
}
void update(int L, int R)
{
  int a = getRank(L), b = getRank(R + 2); //pre(L), nxt(R)
  splay(a, 0); splay(b, a); //現在b的左子樹就是當前區間
  pushdown(root); pushdown(t[root].ch[1]);
  int now = t[t[root].ch[1]].ch[0];
  t[now].lzy ^= 1;
}



還有一件事就是建樹,雖然可以像這道題一樣每一次插入一個數,不過有更可愛的方法。
仿照線段樹的建樹方法,但有一個顯著的區別是線段樹的每一個節點表示一個區間,而splay就表示一個點,所以遞迴的時候把當前區間的\(a[mid]\)作為線段樹該節點的權值,然後到\([L, mid - 1]\)\([mid + 1, R]\)中建立左右子樹。

int build(int L, int R, int f)
{
  if(L > R) return 0;
  int mid = (L + R) >> 1, now = ++ncnt;
  t[now].val = a[mid]; t[now].fa = f;
  t[now].ch[0] = build(L, mid - 1, now);
  t[now].ch[1] = build(mid + 1, R, now);
  pushup(now);
  return now;
}



最後一件事就是輸出。利用splay自身的性質,中序遍歷就是答案。


完整程式碼

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define rg register
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e5 + 5;
inline ll read()
{
  ll ans = 0;
  char ch = getchar(), last = ' ';
  while(!isdigit(ch)) last = ch, ch = getchar();
  while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
inline void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}

int n, m, a[maxn];
struct Tree
{
  int ch[2], fa;
  int val, siz, lzy;
}t[maxn];
int root, ncnt = 0;
void _PrintTr(int now)
{
  if(!now) return;
  printf("nd:%d val:%d ls:%d rs:%d\n", now, t[now].val, t[t[now].ch[0]].val, t[t[now].ch[1]].val);
  _PrintTr(t[now].ch[0]); _PrintTr(t[now].ch[1]);
}
void pushdown(int now)
{
  if(now && t[now].lzy)
    {
      t[t[now].ch[0]].lzy ^= 1; t[t[now].ch[1]].lzy ^= 1;
      swap(t[now].ch[0], t[now].ch[1]);
      t[now].lzy = 0;
    }
}
void pushup(int now)
{
  t[now].siz = t[t[now].ch[0]].siz + t[t[now].ch[1]].siz + 1;
}
void rotate(int x)
{
  int y = t[x].fa, z = t[y].fa, k = (t[y].ch[1] == x);
  t[z].ch[t[z].ch[1] == y] = x; t[x].fa = z;
  t[y].ch[k] = t[x].ch[k ^ 1]; t[t[y].ch[k]].fa = y;
  t[x].ch[k ^ 1] = y; t[y].fa = x;
  pushup(y); pushup(x);
}
void splay(int x, int s)  //旋轉的時候不用pushdown.(因為是自底向上的)
{
  while(t[x].fa != s)
    {
      int y = t[x].fa, z = t[y].fa;
      if(z != s)
    {
      if((t[z].ch[0] == y) ^ (t[y].ch[0] == x)) rotate(x);
      else rotate(y);
    }
      rotate(x);
    }
  if(s == 0) root = x;
}
int build(int L, int R, int f)
{
  if(L > R) return 0;
  int mid = (L + R) >> 1, now = ++ncnt;
  t[now].val = a[mid]; t[now].fa = f;
  t[now].ch[0] = build(L, mid - 1, now);
  t[now].ch[1] = build(mid + 1, R, now);
  pushup(now);
  return now;
}
int getRank(int k)
{
  int now = root;
  while(1)
    {
      pushdown(now);
      if(t[t[now].ch[0]].siz >= k) now = t[now].ch[0];
      else if(t[t[now].ch[0]].siz + 1 == k) return now;
      else k -= t[t[now].ch[0]].siz + 1, now = t[now].ch[1];
    }
}
void update(int L, int R)
{
  int a = getRank(L), b = getRank(R + 2); //pre(L), nxt(R)
  splay(a, 0); splay(b, a); //現在b的左子樹就是當前區間
  pushdown(root); pushdown(t[root].ch[1]);
  int now = t[t[root].ch[1]].ch[0];
  t[now].lzy ^= 1;
}
void print(int now)
{
  pushdown(now);
  if(t[now].ch[0]) print(t[now].ch[0]);
  if(t[now].val != INF && t[now].val != -INF) write(t[now].val), space;
  if(t[now].ch[1]) print(t[now].ch[1]);
}

int main()
{
  n = read(); m = read();
  a[1] = -INF; a[n + 2] = INF;
  for(int i = 1; i <= n; ++i) a[i + 1] = i;
  root = build(1, n + 2, 0);
  //_PrintTr(root);
  for(int i = 1, L, R; i <= m; ++i) L = read(), R = read(), update(L, R);
  print(root), enter;
  return 0;
}