線段樹模板【洛谷P2023】
阿新 • • 發佈:2018-11-21
傳送門:https://www.luogu.org/problemnew/show/P2023
這個題目的區間更新有加法和乘法。
所以比裸的線段樹難一點點吧,也就僅僅是一點點。
既然存在兩個操作,所以我們就要維護兩個tag,一個加法一個乘法。
但是pushdown的時候這兩個tag怎麼pushdown呢?
乘法的優先順序顯然比加法高,所以我們在mul更新的時候要先pushdown,這是一個要點。第二個要點就是,Pushdown的時候,對於乘法tag,我們可以直接乘上父節點的tag,但是對於加法的,我們要怎麼辦呢? 我們要先把該結點的tag乘上乘法tag,然後再加上父節點的tag。為什麼要這樣做呢?因為乘法優先順序高的嘛,這樣就完事了。
看不懂的話,下面我就來推導一下
假設父節點是ax+b
我們要pushdown左兒子。
我們要乘一個k然後加上c
就變成了k(ax+b)+c
變成了kax+kb+c
變成了(ka)x+(kb+c)
右邊的kb+c就變成了add[rt<<1]*mul[rt]+add[rt]。ka就是mul[rt<<1]*mul[rt]。
下面是每次都比分塊慢的線段樹程式碼:
(分塊寫這種區間更新的,不如線段樹方便,我就沒寫分塊的程式碼。)
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int maxn = 1e6+7; ll a[maxn]; ll sum[maxn<<2],add[maxn<<2],mul[maxn<<2]; ll n,p; void pushup(int rt) { sum[rt] = (sum[rt<<1]+sum[rt<<1|1])%p; } void pushdown(int rt,int ln,int rn) { if(add[rt] || mul[rt]!=1) { sum[rt<<1] = (sum[rt<<1]*mul[rt]+add[rt]*ln)%p; sum[rt<<1|1] = (sum[rt<<1|1]*mul[rt]+add[rt]*rn)%p; add[rt<<1] = (add[rt<<1]*mul[rt]+add[rt])%p; add[rt<<1|1] = (add[rt<<1|1]*mul[rt]+add[rt])%p; mul[rt<<1] = (mul[rt<<1]*mul[rt])%p; mul[rt<<1|1] = (mul[rt<<1|1]*mul[rt])%p; add[rt] = 0; mul[rt] = 1; } } void build(int rt,int l,int r) { mul[rt] = 1; if(l==r) { sum[rt] = a[l]%p; return; } int mid = (l+r)/2; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); pushup(rt); } void Add(int x,int y,int l,int r,int rt,int v) { if(x<=l && y>=r) { sum[rt] = (sum[rt]+(r-l+1)*v)%p; add[rt] = (add[rt]+v)%p; return; } int mid = (l+r)/2; pushdown(rt,mid-l+1,r-mid); if(x<=mid) { Add(x,y,l,mid,rt<<1,v); } if(y>mid) { Add(x,y,mid+1,r,rt<<1|1,v); } pushup(rt); } void Mul(int x,int y,int l,int r,int rt,int v) { int mid = (l+r)/2; pushdown(rt,mid-l+1,r-mid); if(x<=l && y>=r) { sum[rt] = (sum[rt]*v)%p; mul[rt] = (mul[rt]*v)%p; return; } if(x<=mid) { Mul(x,y,l,mid,rt<<1,v); } if(y>mid) { Mul(x,y,mid+1,r,rt<<1|1,v); } pushup(rt); } ll query(int x,int y,int l,int r,int rt) { if(x<=l && y>=r) { return sum[rt]; } int mid = (l+r)/2; pushdown(rt,mid-l+1,r-mid); ll ans = 0; if(x<=mid) { ans = (ans+query(x,y,l,mid,rt<<1))%p; } if(y>mid) { ans = (ans+query(x,y,mid+1,r,rt<<1|1))%p; } return ans; } int main() { scanf("%lld%lld",&n,&p); for(int i=1;i<=n;i++) { scanf("%lld",a+i); } build(1,1,n); int m; scanf("%d",&m); for(int i=0;i<m;i++) { int t,x,y,z; scanf("%d",&t); if(t==1) { scanf("%d%d%d",&x,&y,&z); Mul(x,y,1,n,1,z); } else if(t==2) { scanf("%d%d%d",&x,&y,&z); Add(x,y,1,n,1,z); } else { scanf("%d%d",&x,&y); printf("%lld\n",query(x,y,1,n,1)); } } return 0; }