jzoj5943. 【NOIP2018模擬11.01】樹(線段樹)
阿新 • • 發佈:2018-12-19
5943. 【NOIP2018模擬11.01】樹
Description
Input 第一行一個整數 n 表示序列長度, 接下來一行 n 個整數描述這個序列. 第三行一個整數 q 表示操作次數, 接下來 q 行每行一次操作, 格式同題目描述.
Output 輸出等同於操作 2, 3 次數之和的行數, 每行一個非負整數表示對應詢問的答案. 注意操作 2 的答案不需要進行取模.
Sample Input1 5 8 4 3 5 6 5 2 3 5 3 1 2 1 2 4 3 2 3 5 3 1 2
Sample Output1 14 608 10 384
樣例 1 解釋 第三次操作後, 序列變為 [8, 0, 3, 1, 6].
Data Constraint
對於前 30% 的資料, n, q ≤ 100; 對於另 20% 的資料, 沒有操作 1; 對於另 20% 的資料, 沒有操作 3; 對於 100% 的資料, n, q ≤ 10^5, ai ≤ 10^9, k ≤ 2^30, 1 ≤ l ≤ r ≤ n.
分析:顯然修改只會讓數變小, 每個數只會變小 log 次, 所以我們線段樹維護區間或起來的值判斷是否需要修改, 如果需要就暴力下去修改. 複雜度 O(nlog2n)對於操作 3 直接把式子展開, 再維護一個區間平方和, ∑(a[i]+a[j])^2 = 2(r - l + 1)∑a[i]^2 + (∑a[i])^2。
程式碼
#include <cstdio> #include <algorithm> #define N 1000000 #define mo 998244353 #define ll long long using namespace std; struct tree { int l, r; ll sum, o, sq; }tr[N]; ll a[N],s,ss; int n,m; void build(int p, int l, int r) { tr[p].l = l; tr[p].r = r; if (l == r) { tr[p].sum = tr[p].o = a[l]; tr[p].sq = a[l] * a[l] % mo; return; } int mid = (l + r) / 2; build(p * 2, l, mid); build(p * 2 + 1, mid + 1, r); tr[p].sum = tr[p * 2].sum + tr[p * 2 + 1].sum; tr[p].o = tr[p * 2].o | tr[p * 2 + 1].o; tr[p].sq = (tr[p * 2].sq + tr[p * 2 + 1].sq) % mo; } void down(int p, ll k) { if (tr[p].l == tr[p].r) { tr[p].sum = tr[p].o = tr[p].sum & k; tr[p].sq = tr[p].sum * tr[p].sum % mo; return; } down(p * 2, k); down(p * 2 + 1, k); tr[p].sum = tr[p * 2].sum + tr[p * 2 + 1].sum; tr[p].o = tr[p * 2].o | tr[p * 2 + 1].o; tr[p].sq = (tr[p * 2].sq + tr[p * 2 + 1].sq) % mo; } void change(int p, int l, int r, ll k) { if (tr[p].l == l && tr[p].r == r) { if ((tr[p].o & k) != tr[p].o) down(p, k); return; } int mid = (tr[p].l + tr[p].r) / 2; if (r <= mid) change(p * 2, l, r, k); else if (l > mid) change(p * 2 + 1, l, r, k); else change(p * 2, l, mid, k), change(p * 2 + 1, mid + 1, r, k); tr[p].sum = tr[p * 2].sum + tr[p * 2 + 1].sum; tr[p].o = tr[p * 2].o | tr[p * 2 + 1].o; tr[p].sq = (tr[p * 2].sq + tr[p * 2 + 1].sq) % mo; } void find(int p, int l, int r) { if (tr[p].l == l && tr[p].r == r) { s += tr[p].sum; ss = (ss + tr[p].sq) % mo; return; } int mid = (tr[p].l + tr[p].r) / 2; if (r <= mid) find(p * 2, l, r); else if (l > mid) find(p * 2 + 1, l, r); else find(p * 2, l, mid), find(p * 2 + 1, mid + 1, r); } int main() { // freopen("seg.in","r",stdin); // freopen("seg.out","w",stdout); scanf("%d", &n); for (int i = 1; i <= n; i++) scanf("%lld", &a[i]); build(1, 1, n); scanf("%d", &m); while (m--) { int opt, x, y; scanf("%d%d%d", &opt, &x, &y); if (opt == 1) { ll k; scanf("%lld", &k); change(1, x, y, k); } else { s = 0; ss = 0; find(1, x, y); if (opt == 2) printf("%lld\n", s); else { s = s % mo; s = (2ll * ss % mo * (y - x + 1) % mo + 2ll * s % mo * s) % mo; printf("%lld\n", s); } } } }