1. 程式人生 > >線段樹學習(單點更新+區間更新+區間查詢)(C++模板)

線段樹學習(單點更新+區間更新+區間查詢)(C++模板)

一、線段樹的用處

        在對一組連續的資料進行修改或者求和(求最值)操作時,線段樹可以通過快速的修改子區間上的值來達成你的目標。

二、線段樹是什麼

        線段樹是一種二叉搜尋樹,它將一個區間劃分成一些單元區間,每個單元區間對應線段樹中的一個葉結點。使用線段樹可以快速的查詢某一條線段對應的狀態。

        看一副圖來理解(圖片魔改自百度百科):


    可見圖中我們用一個節點1來儲存一段[1,10]線段上的資料,將節點1對半拆開,可以得到節點2和節點3,他們分別儲存的是[1,5]和[6,10]上的資料。以此類推,我們可以分出節點4、節點5……直到不可再分。

三、線段樹的單點更新和查詢(後附區間更新)

    這裡採用的是結構體的方式建立線段樹。

const int maxn = 500005 * 4;	//線段樹範圍要開4倍
struct Tree
{
	int l, r, sum, maxx;
};
Tree node[maxn];	//node[maxn]為線段樹處理使用的陣列
int a[maxn];		//a[maxn]表示讀入的資料	

    結構體中的l,r表示的是該節點所覆蓋的區間為[l,r]。sum表示的是該段區間上的資料總和,maxx表示該段區間上資料的最值。看到之前的那幅圖,我們在將區間進行分割的時候,會出現大量的小區間,所以對於結構體的陣列(即線段樹)大小我們需要設定成條件所給的4倍。

    設定完了節點,那麼我們怎麼來建立一個我們所需要的線段樹呢。再看這副圖。

    

我們該怎麼用程式碼實現連出兩條線使得節點2和節點3成為節點1的子節點呢?

我們可以設當前的節點為i,那麼對於圖上的規律來講,他左邊的子節點編號就是i*2,右邊的子節點編號就是i*2+1

知道了怎麼表示子節點後,我們就要試圖使用他們將葉子節點的值傳遞到下一層。

對於葉子節點來講,他們的值應該是已知的,我們只需要在他們的父節點處進行更新就可以了,這是可以用遞迴實現的。

那麼具體建樹的程式碼就應該是這樣的。

void update(int i)
{
	node[i].sum = node[i << 1].sum + node[(i << 1) | 1].sum;            //求和子節點
	node[i].maxx = max(node[i << 1].maxx, node[(i << 1) | 1].maxx);       //取子節點最值
}

void build(int i, int l, int r)
{
	node[i].l = l; node[i].r = r;
	if (l == r)            //到達了葉子節點直接賦值
	{
		node[i].maxx = a[l];
		node[i].sum = a[l];
		return;
	}
	int mid = (l + r) / 2;
	build(i << 1, l, mid);                //左節點建立
	build((i << 1) | 1, mid + 1, r);        //右節點建立
	update(i);
}

    建立了線段樹後我們要對線段樹上的資料進行修改與求和,該怎麼操作呢?

    對於單點的修改我們找到所需要修改的點k對應的葉子節點,然後一路遞迴更新下去實際上在程式碼和build線段樹是差不多的。

    讓我們直接看程式碼來解釋:

void add(int i, int k, int v)	        //當前更新的節點的編號為i(一般是以1為第一個編號)。
{					//k為需要更新的點的位置,v為修改的值的大小
	if (node[i].l == k&&node[i].r == k)        //左右端點均和k相等,說明找到了k所在的葉子節點
	{
		node[i].sum += v;
		node[i].maxx += v;
		return;    //找到了葉子節點就不需要在向下尋找了
	}
	int mid = (node[i].l + node[i].r) / 2;
	if (k <= mid) add(i << 1, k, v);
	else add((i << 1) | 1, k, v);            //尋找k所在的子區間
	update(i);        //遞迴更新
}

    使用這個add函式就可以實現對線段樹的單點更新啦。比如使k點的值加上v,就是add(1,k,v)。

    求區間的最值程式碼實際上和求和是一樣的,也是先找到對應區間所在的子節點,然後向下遞迴更新。

    求最值程式碼如下

int getmax(int i, int l, int r)
{
	if (node[i].l == l&&node[i].r == r)
		return node[i].maxx;
	int mid = (node[i].l + node[i].r) / 2;
	if (r <= mid) return getmax(i << 1, l, r);
	else if (l>mid) return getmax((i << 1) | 1, l, r);
	else return max(getmax(i << 1, l, mid), getmax((i << 1) | 1, mid + 1, r));
}

    以上。我們已經完成對於一個線段樹的單點更新和查詢。

   模板如下

const int maxn = 500005 * 4;	//線段樹範圍要開4倍
struct Tree
{
	int l, r, sum, maxx;
};
Tree node[maxn];		//node[maxn]為線段樹處理陣列
int a[maxn];			//a[maxn]為原陣列
void update(int i)
{
	node[i].sum = node[i << 1].sum + node[(i << 1) | 1].sum;
	node[i].maxx = max(node[i << 1].maxx, node[(i << 1) | 1].maxx);
}
void build(int i, int l, int r)
{
	node[i].l = l; node[i].r = r;
	if (l == r)
	{
		node[i].maxx = a[l];
		node[i].sum = a[l];
		return;
	}
	int mid = (l + r) / 2;
	build(i << 1, l, mid);
	build((i << 1) | 1, mid + 1, r);
	update(i);
}
int getsum(int i, int l, int r)
{
	if (node[i].l == l&&node[i].r == r)
		return node[i].sum;
	int mid = (node[i].l + node[i].r) / 2;
	if (r <= mid) return getsum(i << 1, l, r);
	else if (l > mid) return getsum((i << 1) | 1, l, r);
	else return getsum(i << 1, l, mid) + getsum((i << 1) | 1, mid + 1, r);
}
int getmax(int i, int l, int r)
{
	if (node[i].l == l&&node[i].r == r)
		return node[i].maxx;
	int mid = (node[i].l + node[i].r) / 2;
	if (r <= mid) return getmax(i << 1, l, r);
	else if (l>mid) return getmax((i << 1) | 1, l, r);
	else return max(getmax(i << 1, l, mid), getmax((i << 1) | 1, mid + 1, r));
}
void add(int i, int k, int v)	        //當前更新的節點的編號為i(一般是1為初始編號,具體得看建立樹時使用的第一個編號是什麼)。
{								//k為需要更新的點的位置,v為修改的值的大小
	if (node[i].l == k&&node[i].r == k)        //左右端點均和k相等,說明找到了k所在的葉子節點
	{
		node[i].sum += v;
		node[i].maxx += v;
		return;    //找到了葉子節點就不需要在向下尋找了
	}
	int mid = (node[i].l + node[i].r) / 2;
	if (k <= mid) add(i << 1, k, v);
	else add((i << 1) | 1, k, v);
	update(i);
}

四、線段樹的區間更新

    為什麼需要把區間更新和單點更新區分開來呢?

    當我們面對給[a,b]範圍上的資料都加上v,這一類的問題時,我們利用單點更新是怎麼操作的呢?

    首先單點更新a+1,再更新a+2,再a+3……直到更新b。那麼對於多個這樣的詢問,顯然運算元是爆表的。所以我們需要一種巧妙的方法,降低我們更新的運算元。

    這裡引入了一個標記陣列,lazy[maxn<<2].

    由字面的意思,這個lazy陣列就是一個給懶人使用標記。

    每當我們需要把一個區間[a,b]都加上v,現在我們其實並沒有直接進入到線段樹的對應區間的子區間去修改,而是先給這個區間做一個標記v,若這個區間有n個數據,當我們查詢時候只需要讀取區間原有的資料並且加上n*v。

    就我的理解lazy標記更像是維護了另一個樹。

    簡單的說就是,我們把向下的修改先儲存起來,而對於每個查詢我們在向上傳遞答案的時候加上這些修改的值。

    用程式碼來實現就是這樣

void PushUp(int rt)
{
	tree[rt].sum = tree[rt << 1].sum + tree[rt << 1 | 1].sum;
}

void PushDown(int rt,int m)        //m表示的是rt對應的當前區間的長度
{
	if (lazy[rt])
	{
		lazy[rt << 1] += lazy[rt];           //延遲的值向左節點傳遞
		lazy[rt << 1 | 1] += lazy[rt];        //延遲的值向右節點傳遞
		tree[rt << 1].sum += lazy[rt] * (m - (m >> 1));   
		tree[rt << 1 | 1].sum += lazy[rt] * (m >> 1);
		lazy[rt] = 0;
	}
}

PushUp函式表示的是向上的更新,PushDown維護的是lazy標記延後的值。

明白了lazy的作用,就可以偷偷放出區間更新的模板啦。

const int N = 100005;
LL a[N];					//a[N]儲存原陣列
LL  lazy[N << 2];			//lazy用來記錄該節點的每個數值應該加多少 
int n, q;
struct Tree
{
	int l, r;
	LL sum;
	int mid()
	{
		return (l + r) >> 1;
	}
}tree[N<<2];		

void PushUp(int rt)
{
	tree[rt].sum = tree[rt << 1].sum + tree[rt << 1 | 1].sum;
}

void PushDown(int rt,int m)
{
	if (lazy[rt])
	{
		lazy[rt << 1] += lazy[rt];
		lazy[rt << 1 | 1] += lazy[rt];
		tree[rt << 1].sum += lazy[rt] * (m - (m >> 1));
		tree[rt << 1 | 1].sum += lazy[rt] * (m >> 1);
		lazy[rt] = 0;
	}
}

void build(int l, int r, int rt)
{
	tree[rt].l = l;
	tree[rt].r = r;
	lazy[rt] = 0;
	if (l == r)
	{
		tree[rt].sum = a[l];
		return;
	}
	int m = tree[rt].mid();
	build(l, m, (rt << 1));
	build(m + 1, r, (rt << 1 | 1));
	PushUp(rt);
}

void update(LL c, int l, int r, int rt)
{
	if (tree[rt].l == l&&tree[rt].r==r)
	{ 
		lazy[rt] += c;
		tree[rt].sum += c*(r - l + 1);
		return;
	}
	if (tree[rt].l == tree[rt].r)return;
	int m = tree[rt].mid();
	PushDown(rt, tree[rt].r - tree[rt].l + 1);
	if (r <= m)update(c, l, r, rt << 1);
	else if (l > m)update(c, l, r, rt << 1 | 1);
	else 
	{
		update(c, l, m, rt << 1);
		update(c, m + 1, r, rt << 1 | 1);
	}
	PushUp(rt);
}

LL Query(int l, int r, int rt)
{
	if (l == tree[rt].l&&r == tree[rt].r)
	{
		return tree[rt].sum;
	}
	int m = tree[rt].mid();
	PushDown(rt, tree[rt].r - tree[rt].l + 1);
	LL res = 0;
	if (r <= m)res += Query(l, r, rt << 1);
	else if (l > m)res += Query(l, r, rt << 1 | 1);
	else
	{
		res += Query(l, m, rt << 1);
		res += Query(m + 1, r, rt << 1 | 1);
	}
	return res;
}

附上線段樹模板題:

以及AC程式碼

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<string>
#include<algorithm>
#include<vector>
#include<queue>
#include<set>
#include<map>
#include<stack>
#include<list>
using namespace std;
const int INF = 0x3f3f3f3f;
#define LL long long int 
long long  gcd(long long  a, long long  b) { return a == 0 ? b : gcd(b % a, a); }



const int N = 100005;
LL a[N];					
LL  lazy[N << 2];			
int n, q;
struct Tree
{
	int l, r;
	LL sum;
	int mid()
	{
		return (l + r) >> 1;
	}
}tree[N<<2];		

void PushUp(int rt)
{
	tree[rt].sum = tree[rt << 1].sum + tree[rt << 1 | 1].sum;
}

void PushDown(int rt,int m)
{
	if (lazy[rt])
	{
		lazy[rt << 1] += lazy[rt];
		lazy[rt << 1 | 1] += lazy[rt];
		tree[rt << 1].sum += lazy[rt] * (m - (m >> 1));
		tree[rt << 1 | 1].sum += lazy[rt] * (m >> 1);
		lazy[rt] = 0;
	}
}

void build(int l, int r, int rt)
{
	tree[rt].l = l;
	tree[rt].r = r;
	lazy[rt] = 0;
	if (l == r)
	{
		tree[rt].sum = a[l];
		return;
	}
	int m = tree[rt].mid();
	build(l, m, (rt << 1));
	build(m + 1, r, (rt << 1 | 1));
	PushUp(rt);
}

void update(LL c, int l, int r, int rt)
{
	if (tree[rt].l == l&&tree[rt].r==r)
	{ 
		lazy[rt] += c;
		tree[rt].sum += c*(r - l + 1);
		return;
	}
	if (tree[rt].l == tree[rt].r)return;
	int m = tree[rt].mid();
	PushDown(rt, tree[rt].r - tree[rt].l + 1);
	if (r <= m)update(c, l, r, rt << 1);
	else if (l > m)update(c, l, r, rt << 1 | 1);
	else 
	{
		update(c, l, m, rt << 1);
		update(c, m + 1, r, rt << 1 | 1);
	}
	PushUp(rt);
}

LL Query(int l, int r, int rt)
{
	if (l == tree[rt].l&&r == tree[rt].r)
	{
		return tree[rt].sum;
	}
	int m = tree[rt].mid();
	PushDown(rt, tree[rt].r - tree[rt].l + 1);
	LL res = 0;
	if (r <= m)res += Query(l, r, rt << 1);
	else if (l > m)res += Query(l, r, rt << 1 | 1);
	else
	{
		res += Query(l, m, rt << 1);
		res += Query(m + 1, r, rt << 1 | 1);
	}
	return res;
}

int main()
{
	while (scanf("%d%d", &n, &q) != EOF)
	{
		for (int i = 1; i <= n; i++)
			scanf("%lld", &a[i]);
		build(1, n, 1);
		char t;
		int a, b;
		LL c;
		while (q--)
		{
			getchar();
			scanf("%c", &t);
			if (t == 'Q')
			{
				scanf("%d %d", &a, &b);
				printf("%lld\n", Query(a, b, 1));
			}
			else if (t == 'C')
			{
				scanf("%d %d %lld", &a, &b, &c);
				update(c, a, b, 1);
			}
		}
	}
	getchar();
	getchar();
}