1. 程式人生 > >jzoj5943. 【NOIP2018模擬11.01】樹(線段樹)

jzoj5943. 【NOIP2018模擬11.01】樹(線段樹)

5943. 【NOIP2018模擬11.01】樹

Description

在這裡插入圖片描述

Input 第一行一個整數 n 表示序列長度, 接下來一行 n 個整數描述這個序列. 第三行一個整數 q 表示操作次數, 接下來 q 行每行一次操作, 格式同題目描述.

Output 輸出等同於操作 2, 3 次數之和的行數, 每行一個非負整數表示對應詢問的答案. 注意操作 2 的答案不需要進行取模.

Sample Input1 5 8 4 3 5 6 5 2 3 5 3 1 2 1 2 4 3 2 3 5 3 1 2

Sample Output1 14 608 10 384

樣例 1 解釋 第三次操作後, 序列變為 [8, 0, 3, 1, 6].

Data Constraint

對於前 30% 的資料, n, q ≤ 100; 對於另 20% 的資料, 沒有操作 1; 對於另 20% 的資料, 沒有操作 3; 對於 100% 的資料, n, q ≤ 10^5, ai ≤ 10^9, k ≤ 2^30, 1 ≤ l ≤ r ≤ n.

分析:顯然修改只會讓數變小, 每個數只會變小 log 次, 所以我們線段樹維護區間或起來的值判斷是否需要修改, 如果需要就暴力下去修改. 複雜度 O(nlog2n)對於操作 3 直接把式子展開, 再維護一個區間平方和, ∑(a[i]+a[j])^2 = 2(r - l + 1)∑a[i]^2 + (∑a[i])^2。

程式碼

#include <cstdio>
#include <algorithm>
#define N 1000000
#define mo 998244353
#define ll long long
using namespace std;

struct tree
{
	int l, r;
	ll sum, o, sq;
}tr[N];
ll a[N],s,ss;
int n,m;

void build(int p, int l, int r)
{
	tr[p].l = l;
	tr[p].r = r;
	if (l == r)
	{
		tr[p].sum = tr[p].o = a[l];
		tr[p].sq = a[l] * a[l] % mo;
		return;
	}
	int mid = (l + r) / 2;
	build(p * 2, l, mid);
	build(p * 2 + 1, mid + 1, r);
	tr[p].sum = tr[p * 2].sum + tr[p * 2 + 1].sum;
	tr[p].o = tr[p * 2].o | tr[p * 2 + 1].o;
	tr[p].sq = (tr[p * 2].sq + tr[p * 2 + 1].sq) % mo;
}

void down(int p, ll k)
{
	if (tr[p].l == tr[p].r)
	{
		tr[p].sum = tr[p].o = tr[p].sum & k;
		tr[p].sq = tr[p].sum * tr[p].sum % mo;
		return;
	}
	down(p * 2, k);
	down(p * 2 + 1, k);
	tr[p].sum = tr[p * 2].sum + tr[p * 2 + 1].sum;
	tr[p].o = tr[p * 2].o | tr[p * 2 + 1].o;
	tr[p].sq = (tr[p * 2].sq + tr[p * 2 + 1].sq) % mo;
}

void change(int p, int l, int r, ll k)
{
	if (tr[p].l == l && tr[p].r == r)
	{
		if ((tr[p].o & k) != tr[p].o) down(p, k);
		return;
	}
	int mid = (tr[p].l + tr[p].r) / 2;
	if (r <= mid) change(p * 2, l, r, k);
		else if (l > mid) change(p * 2 + 1, l, r, k);
			else change(p * 2, l, mid, k), change(p * 2 + 1, mid + 1, r, k);
	tr[p].sum = tr[p * 2].sum + tr[p * 2 + 1].sum;
	tr[p].o = tr[p * 2].o | tr[p * 2 + 1].o;
	tr[p].sq = (tr[p * 2].sq + tr[p * 2 + 1].sq) % mo;
}

void find(int p, int l, int r)
{
	if (tr[p].l == l && tr[p].r == r) 
	{
		s += tr[p].sum;
		ss = (ss + tr[p].sq) % mo;
		return;
	}
	int mid = (tr[p].l + tr[p].r) / 2;
	if (r <= mid) find(p * 2, l, r);
		else if (l > mid) find(p * 2 + 1, l, r);
			else find(p * 2, l, mid), find(p * 2 + 1, mid + 1, r);
}

int main()
{
//	freopen("seg.in","r",stdin);
//	freopen("seg.out","w",stdout);
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) scanf("%lld", &a[i]); 
	build(1, 1, n);
	scanf("%d", &m);
	while (m--)
	{
		int opt, x, y;
		scanf("%d%d%d", &opt, &x, &y);
		if (opt == 1)
		{
			ll k;
			scanf("%lld", &k);
			change(1, x, y, k);
		}
		else
		{
			s = 0;
			ss = 0;
			find(1, x, y);
			if (opt == 2) printf("%lld\n", s);
				else
				{
					s = s % mo;
					s = (2ll * ss % mo * (y - x + 1) % mo + 2ll * s % mo * s) % mo;
					printf("%lld\n", s);
				}
		}
	}
}