【演算法微解讀】淺談線段樹
淺談線段樹
(來自TRTTG大佬的供圖)
線段樹個人理解和運用時,認為這個是一個比較實用的優化演算法。
這個東西和區間樹有點相似,是一棵二叉搜尋樹,也就是查詢節點和節點所帶值的一種演算法。
使用線段樹可以快速的查詢某一個節點在若干條線段中出現的次數,時間複雜度為O(logN),這個時間複雜度非常的理想,但是空間複雜度在應用時是開4N的。
所以這個演算法有優勢,也有劣勢。
我們提出一個問題
如果當前有一個區間,需要你在給定區間內做以下操作:
- l,z 在l上加上z
- l 查詢l的值
- l,r,z 在[l,r]區間所有數都+z
- l,r, 查詢l到r之間的和
你是不是在想,暴力解決一切問題,但是如果給你的資料是極大的,暴力完全做不了。
那麼我們就需要使用線段樹了。
我們就以這個問題為例來對線段樹進行講解。
先提供一下這個題目的AC程式碼
#include <bits/stdc++.h> using namespace std; const int maxn=10010; struct segment_tree{ int l,r,sum,lazy; }tree[maxn<<2]; int a[maxn]; int n,m; void pushup(int nod) { tree[nod].sum=tree[nod<<1].sum+tree[(nod<<1)+1].sum; } void pushdown(int nod,int l,int r) { int mid=(l+r)>>1; tree[nod<<1].sum+=(mid-l+1)*tree[nod].lazy; tree[(nod<<1)+1].sum+=(r-mid)*tree[nod].lazy; tree[nod<<1].lazy+=tree[nod].lazy; tree[(nod<<1)+1].lazy+=tree[nod].lazy; tree[nod].lazy=0; } void build(int l,int r,int nod) { if (l==r) { tree[nod].sum=a[l]; tree[nod].l=l; tree[nod].r=r; tree[nod].lazy=0; return; } int mid=(l+r)>>1; build(l,mid,nod<<1); build(mid+1,r,(nod<<1)+1); pushup(nod); } void update1(int l,int r,int k,int value,int nod) { if (l==r) { tree[nod].sum+=value; return ; } int mid=(l+r)>>1; pushdown(nod,l,r); if (k<=mid) update1(l,mid,k,value,nod<<1); else update1(mid+1,r,k,value,(nod<<1)+1); pushup(nod); } int query1(int l,int r,int nod,int k) { if (l==r) return tree[nod].sum; int mid=(l+r)>>1; pushdown(nod,l,r); if (k<=mid) return query1(l,mid,nod<<1,k); else return query1(mid+1,r,(nod<<1)+1,k); } void update2(int l,int r,int ll,int rr,int nod,int value) { if (l==ll&&r==rr) { tree[nod].sum+=(r-l+1)*value; tree[nod].lazy+=value; return; } pushdown(nod,l,r); int mid=(l+r)>>1; if (rr<=mid) update2(l,mid,ll,rr,nod<<1,value); else if (l>mid) update2(mid+1,r,ll,rr,(nod<<1)+1,value); else { update2(l,mid,ll,mid,nod<<1,value); update2(mid+1,r,mid+1,rr,(nod<<1)+1,value); } pushup(nod); } int query2(int l,int r,int ll,int rr,int nod) { if (l==ll&r==rr) { return tree[nod].sum; } pushdown(nod,l,r); int mid=(l+r)>>1; if (rr<=mid) return query2(l,mid,ll,rr,nod<<1); else if (ll>mid) return query2(mid+1,r,ll,rr,(nod<<1)+1); else return query2(l,mid,ll,mid,nod<<1)+query2(mid+1,r,mid+1,rr,(nod<<1)+1); } int main() { scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) scanf("%d",&a[i]); build(1,n,1); while (m--) { int c,x,y,z; scanf("%d",&c); if (c==1) { scanf("%d%d",&x,&y); update1(1,n,x,y,1); } if (c==2) { scanf("%d",&x); printf("%d\n",query1(1,n,1,x)); } if (c==3) { scanf("%d%d%d",&x,&y,&z); update2(1,n,x,y,1,z); } if (c==4) { scanf("%d%d",&x,&y); printf("%d\n",query2(1,n,x,y,1)); } } return 0; }
線段樹的一些基本操作
- 建樹
- 單點修改
- 單點查詢
- 區間修改
- 區間查詢
- pushup(兒子把資訊傳給父親)
- pushdown(父親把資訊傳給兒子)
(其他的應該都是這些基本操作的變形)
以下我們來逐一講解一下
結構體
作為一課非常正經的樹,我們還是要給它開一個結構體。
struct segment_tree{ int l,r,sum; }tree[maxn];
關於線段樹的一些小提醒
我們寫線段樹,應該先知道當前節點nod的左右兒子的編號是多少,答案是(nod 2)和(nod 2+1)
為什麼?我們寫的線段樹應該是一棵滿二叉樹,所以根據滿二叉樹節點的特點,我們就可以知道了他的兒子就是以上的答案。
建樹
由於是二叉搜尋樹,也就是一個二叉樹,需要做搜尋操作。那麼我們就是以樹狀結構來儲存資料。
我們來了解一下線段樹:
我們設當前的線段樹的節點是 \[ tree.l\ tree.r \] ,也就是當前這段區間的左右l和r。(其實我們在寫程式碼的時候一般是不寫這個l和r的)
其次我們還需要當前節點 \[ tree.sum \] ,表示當前節點所帶的值。
在後面我們會講到 \[ tree.lazy \] ,表示當前節點的懶標記,來方便我們進行區間修改的一個東西,我們現在先不講
線段樹的基本思想:二分。

void build(int l,int r,int nod) { if (l==r) { tree[nod].sum=a[l]; tree[nod].l=l; tree[nod].r=r; return; } int mid=(l+r)>>1; build(l,mid,nod<<1); build(mid+1,r,(nod<<1)+1); pushup(nod); }
有人在問這個pushup是什麼東西?
pushup
pushup就是把兒子的資訊上傳給自己的父親節點
以當前問題為例,那麼這個pushup的過程就是以下程式
void pushup(int nod) { tree[nod].sum=tree[nod<<1].sum+tree[(nod<<1)+1].sum; }
其實也就是把和上傳給父親,非常簡單,其他的pushup都是這個道理
單點修改

我們單點修改只需要直接在原節點上修改就可以了。
那麼我們廢話不多說,直接上程式碼更好理解
void update(int l,int r,int k,int value,int nod){ if(l==r) { tree[nod].sum+=value; return; } int mid=(l+r)/2; if(k<=mid)update(l,mid,k,value,nod*2); else update(mid+1,r,k,value,nod*2+1); pushup(nod); return; }
這段程式也就是左右查詢當前節點,k是我們需要尋找的節點,如果在左區間,那麼就在左區間查詢,有區間也是這個意思。
單點查詢
方法與二分查詢基本一致,如果當前列舉的點左右端點相等,即葉子節點,就是目標節點。如果不是,因為這是二分法,所以設查詢位置為x,當前結點區間範圍為了l,r,中點為mid,則如果x<=mid,則遞迴它的左孩子,否則遞迴它的右孩子。
直接上程式碼
int query(int l,int r,int ll,int rr,int nod){ if(l==ll&&r==rr)return tree[nod].sum; int mid=(l+r)/2; if(rr<=mid)return query(l,mid,ll,rr,nod*2); else if(ll>mid)return query(mid+1,r,ll,rr,nod*2+1); else return query(l,mid,ll,mid,nod*2)+query(mid+1,r,mid+1,rr,nod*2+1); }
非常的簡單我們就不多說了
區間修改
我們思考一個問題,如果我們只是像單點修改那樣子,用一個迴圈語句,把要修改區間內的所有點都進行單點修改,那麼這個的複雜度應該是O(NlogN),那麼這就無法發揮出線段樹的優勢了。
那麼我們應該怎麼做呢?
這個時候我們就需要引入一個叫做懶標記的東西。
顧名思義,這個就是一個非常懶的標記,這個就是在我們要的區間內的節點上所加的標記,這個標記也就只有我們要對父親區間內的數進行修改或者附其他值的時候才會用到的一個東西。

這個標記比較難理解,所以我們稍微講的詳細一點?
首先如果要對一個區間內的節點進行修改,那麼就只需要在所需的區間內進行修改,也就只是放在那裡,讓他不要動。
當你要對接下來的區間內的數進行詢問時,我們就需要進行pushdown的操作,這個操作就是要把父親的懶標記上所擁有的全部資訊全部給自己的兒子。
再傳給兒子後,我們的父親就要刪除自己的懶標記,因為自己的懶標記已經傳給了自己的兒子了,為了不產生錯誤,我們就要刪除父親的懶標記。
還是與我們這個例題為例,我們的區間修改的應該是這樣寫的:
void update2(int l,int r,int ll,int rr,int nod,int value) { if (l==ll&&r==rr) { tree[nod].sum+=(r-l+1)*value; tree[nod].lazy+=value; return; } pushdown(nod,l,r); int mid=(l+r)>>1; if (rr<=mid) update2(l,mid,ll,rr,nod<<1,value); else if (l>mid) update2(mid+1,r,ll,rr,(nod<<1)+1,value); else { update2(l,mid,ll,mid,nod<<1,value); update2(mid+1,r,mid+1,rr,(nod<<1)+1,value); } pushup(nod); }
我們再回到這個問題,為什麼會有這麼多的if語句,我們現在來講解一下
ll,rr是需要修改的區間。
當你的區間的rr也就是最右邊在mid的左邊,那麼說明我們整個區間就在l和mid之間,就是以下的情況
好了右區間也是一樣,其他的情況就是當前的區間分佈在mid的左右,那麼就分成兩部分修改就可以了
那麼最後因為兒子可能被改變了,所以我們就要pushup一下。
小提醒
如果你實在不知道什麼時候要pushup或者是pushdown,那麼多多益善,這樣只是會增高你的時間複雜度,而不會影響正確率。
pushdown
這個操作在上文已經講過是把父親的lazy下傳給兒子的過程。
直接上程式碼
void pushdown(int nod,int l,int r) { int mid=(l+r)>>1; tree[nod<<1].sum+=(mid-l+1)*tree[nod].lazy; tree[(nod<<1)+1].sum+=(r-mid)*tree[nod].lazy; tree[nod<<1].lazy+=tree[nod].lazy; tree[(nod<<1)+1].lazy+=tree[nod].lazy; tree[nod].lazy=0; }
區間查詢


這個道理和區間修改差不多,還更簡單一點。
也不多講了,直接上程式碼
int query2(int l,int r,int ll,int rr,int nod) { if (l==ll&r==rr) { return tree[nod].sum; } pushdown(nod,l,r); int mid=(l+r)>>1; if (rr<=mid) return query2(l,mid,ll,rr,nod<<1); else if (ll>mid) return query2(mid+1,r,ll,rr,(nod<<1)+1); else return query2(l,mid,ll,mid,nod<<1)+query2(mid+1,r,mid+1,rr,(nod<<1)+1); }
一些模板題
ofollow,noindex" target="_blank">codevs線段樹練習
#include<bits/stdc++.h> using namespace std; const int N=100000; int tree[N*4+10],s[N]; void build(int l,int r,int nod) { if(l==r){tree[nod]=s[l];return;} int mid=(l+r)/2; build(l,mid,2*nod); build(mid+1,r,nod*2+1); tree[nod]=tree[nod*2]+tree[nod*2+1]; return; } void update(int l,int r,int k,int value,int nod){ if(l==r){tree[nod]+=value;return;} int mid=(l+r)/2; if(k<=mid)update(l,mid,k,value,nod*2); else update(mid+1,r,k,value,nod*2+1); tree[nod]=tree[nod*2]+tree[nod*2+1]; return; } int query(int l,int r,int ll,int rr,int nod){ if(l==ll&&r==rr)return tree[nod]; int mid=(l+r)/2; if(rr<=mid)return query(l,mid,ll,rr,nod*2); else if(ll>mid)return query(mid+1,r,ll,rr,nod*2+1); else return query(l,mid,ll,mid,nod*2)+query(mid+1,r,mid+1,rr,nod*2+1); } int main() { int n,m; scanf("%d",&n); for(int i=1;i<=n;i++)scanf("%d",&s[i]); build(1,n,1); scanf("%d",&m); while(m--){ int x,y,z; scanf("%d%d%d",&x,&y,&z); if(x==1)update(1,n,y,z,1); else printf("%d\n",query(1,n,y,z,1)); } return 0; }
codevs線段樹練習2
#include<bits/stdc++.h> using namespace std; const int N=1000000; int tree[N*4+10],a[N]; void update(int nod,int l,int r,int ll,int rr,int value){ if(l==ll&&r==rr){tree[nod]+=value;return;} int mid=(l+r)/2; if(rr<=mid)update(2*nod,l,mid,ll,rr,value); else if(ll>mid)update(nod*2+1,mid+1,r,ll,rr,value); else{ update(2*nod,l,mid,ll,mid,value); update(2*nod+1,mid+1,r,mid+1,rr,value); } return; } void pushdown(int nod){ tree[nod*2+1]+=tree[nod]; tree[nod*2]+=tree[nod]; tree[nod]=0; return; } int query(int nod,int l,int r,int k){ if(l==r)return a[l]+tree[nod]; int mid=(l+r)/2; pushdown(nod); if(k<=mid)return query(2*nod,l,mid,k); else return query(2*nod+1,mid+1,r,k); } int main() { int n,m; scanf("%d",&n); for(int i=1;i<=n;i++)scanf("%d",&a[i]); scanf("%d",&m); while(m--){ int x,y,z,k; scanf("%d",&x); if(x==1){ scanf("%d%d%d",&y,&z,&k); update(1,1,n,y,z,k); } else{ scanf("%d",&y); printf("%d\n",query(1,1,n,y)); } } return 0; }
codevs線段樹練習4
#include<bits/stdc++.h> using namespace std; const int N(200000); struct node{ long long sum,add; }tree[4*N+10]; int a[N+10]; inline void pushdown(long long nod,long long l,long long r){ long long mid((l+r)>>1); tree[nod<<1].sum+=(mid-l+1)*tree[nod].add; tree[(nod<<1)+1].sum+=(r-mid)*tree[nod].add; tree[nod<<1].add+=tree[nod].add; tree[(nod<<1)+1].add+=tree[nod].add; tree[nod].add=0; return; } inline long long read(){ long long x(0); char ch=getchar(); while(ch<'0'||ch>'9')ch=getchar(); while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x; } void pushup(long long nod){ tree[nod].sum=tree[nod<<1].sum+tree[(nod<<1)+1].sum; return; } void build(long long l,long long r,long long nod){ tree[nod].add=0; if(l==r){ tree[nod].sum=a[l]; return; } long long mid((l+r)>>1); build(l,mid,nod<<1); build(mid+1,r,(nod<<1)+1); tree[nod].sum=tree[nod<<1].sum+tree[(nod<<1)+1].sum; return; } void update(long long l,long long r,long long ll,long long rr,long long value,long long nod){ if(l==ll&&r==rr){ tree[nod].sum+=(r-l+1)*value; tree[nod].add+=value; return; } pushdown(nod,l,r); long long mid((l+r)>>1); if(rr<=mid)update(l,mid,ll,rr,value,nod<<1); else if(ll>mid)update(mid+1,r,ll,rr,value,(nod<<1)+1); else{ update(l,mid,ll,mid,value,nod<<1); update(mid+1,r,mid+1,rr,value,(nod<<1)+1); } pushup(nod); return; } long long query(long long l,long long r,long long ll,long long rr,long long nod){ if(l==ll&&r==rr)return tree[nod].sum; pushdown(nod,l,r); long long mid=(l+r)>>1; if(rr<=mid)return query(l,mid,ll,rr,nod<<1); else if(ll>mid)return query(mid+1,r,ll,rr,(nod<<1)+1); else return query(l,mid,ll,mid,nod*2)+query(mid+1,r,mid+1,rr,(nod<<1)+1); } int main() { long long m; register long long n; m=read(); for(long long i=1;i<=m;++i)a[i]=read(); build(1,m,1); n=read(); while(n--){ long long t,x,y,z; t=read(); if(t==1){ x=read(); y=read(); z=read(); update(1,m,x,y,z,1); } else{ x=read(); y=read(); printf("%lld\n",query(1,m,x,y,1)); } } return 0; }
codevs線段樹練習4
#include<cstdio> #include<cstring> #include<algorithm> #include<cmath> using namespace std; const int N=1000000; int add[N],sum[N*4+10][7],a[N]; inline int read(){ int x(0); char ch=getchar(); while(ch<'0'||ch>'9')ch=getchar(); while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x; } void pushup(int nod){ for(int i=0;i<7;i++) sum[nod][i]=sum[nod<<1][i]+sum[(nod<<1)+1][i]; return; } void build(int l,int r,int nod){ if(l==r){ sum[nod][a[l]%7]++; return; } int mid((l+r)>>1); build(l,mid,nod<<1); build(mid+1,r,(nod<<1)+1); pushup(nod); return; } void modify(int nod,int v){ int t[7]; for(int i=0;i<7;i++) t[(i+v)%7]=sum[nod][i]; for(int i=0;i<7;i++) sum[nod][i]=t[i]; add[nod]=(add[nod]+v)%7; return; } void pushdown(int nod){ modify(nod<<1,add[nod]); modify((nod<<1)+1,add[nod]); add[nod]=0; return; } int query(int l,int r,int ll,int rr,int nod){ if(l==ll&&r==rr) return sum[nod][0]; int mid((l+r)>>1); pushdown(nod); if(rr<=mid)query(l,mid,ll,rr,nod<<1); else if(ll>mid)query(mid+1,r,ll,rr,(nod<<1)+1); else return query(l,mid,ll,mid,nod<<1)+query(mid+1,r,mid+1,rr,(nod<<1)+1); } void update(int l,int r,int ll,int rr,int value,int nod){ if(l==ll&&r==rr){ modify(nod,value); return; } int mid((l+r)>>1); pushdown(nod); if(rr<=mid)update(l,mid,ll,rr,value,nod<<1); else if(ll>mid)update(mid+1,r,ll,rr,value,(nod<<1)+1); else{ update(l,mid,ll,mid,value,nod<<1); update(mid+1,r,mid+1,rr,value,(nod<<1)+1); } pushup(nod); return; } int main() { int n; n=read(); for(int i=1;i<=n;i++)scanf("%d",&a[i]); build(1,n,1); int q; q=read(); while(q--){ char s[10]; scanf("%s",s); if(s[0]=='c'){ int x,y; x=read(); y=read(); printf("%d\n",query(1,n,x,y,1)); } else{ int x,y,z; x=read(); y=read(); z=read(); update(1,n,x,y,z,1); } } return 0; }