1. 程式人生 > >【資料結構】【可持久化資料結構--線段樹】

【資料結構】【可持久化資料結構--線段樹】

可持久化資料結構就是讓某種資料結構可以持久的訪問以前的歷史版本,同時也可以從某個歷史版本再建一個新的版本的資料結構,比如可持久化線段樹,並查集,treap。
正常的線段樹長這個樣子。

但是如果我對這棵樹進行一些奇奇怪怪的操作那麼我們就無法查詢以前的版本了,比如說我們要查詢區間第k大,那麼我們可以通過查詢樹找到第k大的數,但是對於區間第k大就不行了,所以我們要把每個歷史版本都存下了,很明顯,直接存空間是會炸的,但是我們可以利用空間,說以我們可以把不同歷史版本的相同部分利用起來,於是各個歷史版本之間就會有一些共用的節點。比如我把上面這顆線段樹先整成一個空樹,再加一個2,就變成了下面這個圖
這裡寫圖片描述


在這裡我們新加了一個樹根,因此我們要把每一個歷史版本的數根用root[MAXN]存下來,由於在這個問題上我們只需要把2的所有父節點建出來,因此插入第一個元素後,土紅色是新生成的結點,它們是cnt都是1,其餘的藍色結點的cnt仍是0。每插入一個元素,僅新增log(n)個結點,耗時log(n),因此建樹的時間、空間均為nlog(n)
那麼我們要查詢區間就是從兩個歷史版本的root開始相減。就可以像正常的一樣查找了。
這個是程式碼。

#include<cstdio>
#include<cstring>
#include<algorithm>
using
namespace std; const int MAXN=100000; struct node { int tot; int ch[2]; }; node tree[MAXN*20]; int ncnt,root[MAXN+5]; int a[MAXN+5],b[MAXN+5]; int n,m,bn; void init() { ncnt=0; } void insert(int &cur,int val,int l,int r) { ncnt++; tree[ncnt]=tree[cur]; cur=ncnt; tree[cur].tot++; if
(l==r) return ; int mid=(l+r)>>1; if(val<=mid) insert(tree[cur].ch[0],val,l,mid); else insert(tree[cur].ch[1],val,mid+1,r); } void prepare() { scanf("%d %d",&n,&m); for(int i=1;i<=n;i++) { scanf("%d",&a[i]); b[i]=a[i]; } sort(b+1,b+n+1); bn=unique(b+1,b+n+1)-b-1; root[0]=0; for(int i=1;i<=n;i++) { int val=lower_bound(b+1,b+n+1,a[i])-b; root[i]=root[i-1]; insert(root[i],val,1,bn); } } int find(int x,int y,int k,int l,int r) { if(l==r) return l; int mid=(l+r)>>1; int t=tree[tree[x].ch[0]].tot-tree[tree[y].ch[0]].tot; if(k<=t) return find(tree[x].ch[0],tree[y].ch[0],k,l,mid); else return find(tree[x].ch[1],tree[y].ch[1],k-t,mid+1,r); } void solve() { int l,r,k; for(int i=1;i<=m;i++) { scanf("%d %d %d",&l,&r,&k); printf("%d\n",b[find(root[r],root[l-1],k,1,bn)]); } } int main() { init(); prepare(); solve(); }

不過這是不帶修改的,帶修改的由於要影響以前的n個版本,如果暴力更新,時間就會直接炸掉,所以我們用樹套樹的的辦法來解決,即再加一個樹狀陣列來維護序列的字首和。
但是怎麼樹套樹呢?這裡有一個很簡單的理解方法.
我們先看一下樹狀樹組 .
這裡寫圖片描述
這個圖很簡單的顯示了對應的陣列下標的對應區間,所以我們的第i個歷史版本的線段樹不再是記錄字首和,而是記錄對應的區間(如上圖),這樣每次更改和查詢都是O(logn*logn) 的。
這是帶修改的。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int MAXN=500005,MAXQ=100005;
struct node
{
    int tot;
    int ch[2];
};
node tree[MAXN*20];
int root[MAXN];
int ncnt;
int a[MAXN],b[MAXN*2];
int n,m,bn;
int cntL,cntR,L[MAXN],R[MAXN];
int pos[MAXN];
struct cmd
{
    int l,r,k;
    char op;
}query[MAXQ];
void SegUpdate(int &cur,int l,int r,int pos,int d)
{
    if(!cur)
    {
        ncnt++;
        tree[ncnt]=tree[cur];
        cur=ncnt;
    }
    tree[cur].tot+=d;
    if(l==r)
        return ;
    int mid=(l+r)>>1;
    if(pos<=mid)
        SegUpdate(tree[cur].ch[0],l,mid,pos,d);
    else
        SegUpdate(tree[cur].ch[1],mid+1,r,pos,d);
}
void BitUpdate(int x,int pos,int d)
{
    while(x<=n)
    {
        SegUpdate(root[x],1,bn,pos,d);
        x+=x&-x;
    }
}
int GetPos(int x)
{
    return lower_bound(b+1,b+bn+1,x)-b;
}
void prepare()
{
    char s[20];
    scanf("%d %d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
        b[i]=a[i];
    }
    bn=n;
    for(int i=1;i<=m;i++)
    {
        scanf("%s",s);
        query[i].op=s[0];
        if(s[0]=='Q')
            scanf("%d %d %d",&query[i].l,&query[i].r,&query[i].k);
        else
        {
            scanf("%d %d",&query[i].l,&query[i].k);
            b[++bn]=query[i].k;
        }
    }
    sort(b+1,b+bn+1);
    bn=unique(b+1,b+bn+1)-b-1;
    root[0]=0;
    for(int i=1;i<=n;i++)
    {
        pos[i]=GetPos(a[i]);
        BitUpdate(i,pos[i],1);
    }
}
int SegKth(int l,int r,int k)
{
    if(l==r)
        return l;
    int mid=(l+r)>>1;
    int tl=0,tr=0,sum;
    for(int i=1;i<=cntL;i++)
        tl+=tree[tree[L[i]].ch[0]].tot;
    for(int i=1;i<=cntR;i++)
        tr+=tree[tree[R[i]].ch[0]].tot;
    sum=tr-tl;
    if(k<=sum)
    {
        for(int i=1;i<=cntL;i++)
        L[i]=tree[L[i]].ch[0];
        for(int i=1;i<=cntR;i++)
        R[i]=tree[R[i]].ch[0];
        return SegKth(l,mid,k);
    }
    else
    {
        for(int i=1;i<=cntL;i++)
        L[i]=tree[L[i]].ch[1];
        for(int i=1;i<=cntR;i++)
        R[i]=tree[R[i]].ch[1];
        return SegKth(mid+1,r,k-sum);
    }
}
int BitKth(int st,int ed,int k)
{
    cntL=cntR=0;
    while(st>0)
    {
        L[++cntL]=root[st];
        st-=st&-st;
    }
    while(ed>0)
    {
        R[++cntR]=root[ed];
        ed-=ed&-ed;
    }
    int pos=SegKth(1,bn,k);
    return b[pos];
}

void solve(){
    int st=0,ed=0,k=0,ans=0;
    for(int i=1;i<=m;i++)
    {
        if(query[i].op=='Q')
        {
            st=query[i].l;
            ed=query[i].r;
            k=query[i].k;
            ans=BitKth(st-1,ed,k);
            printf("%d\n",ans);
        }
        else
        {
            st=query[i].l,k=query[i].k;
            BitUpdate(st,pos[st],-1);
            pos[st]=GetPos(k);
            BitUpdate(st,pos[st],1);
        }
    }
}
int main()
{
    ncnt=0;
    prepare();
    solve();
}