1. 程式人生 > >bzoj2243: [SDOI2011]染色(樹鏈剖分)

bzoj2243: [SDOI2011]染色(樹鏈剖分)

bzoj2243

樹鏈剖分好題啊!

題目描述:給定一顆n個點的樹,有m個操作,操作有兩種。

                 1、將節點a到節點b路徑上所有的點都染成顏色c。

                 2、詢問節點a到節點b路徑上的顏色段數量(連續的被認為是同一段)。

 

輸入格式:第一行包含兩個整數n和m,表示節點數和操作個數。

                 第二行n個整數,表示每個節點的初始顏色。

                 接下來n - 1行,每行兩個整數描述一棵樹。

                 接下來m行,每行表示一個操作。

 

輸出格式:對於每一個詢問顏色段數量的操作,輸出一行一個整數,表示顏色段的數量。

 

輸入樣例:

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

 

輸出樣例:

3
1
2

 

解析:很顯然是樹剖,問題是如何線上段樹上維護不同顏色的段數。

          用sum[o]表示一段的不同顏色的段數,lf[o]表示這一段最左邊的顏色,rt[o]表示這一段最右邊的顏色。

          那麼在進行合併時,lf[o] = lf[o << 1],rt[o] = rt[o << 1 | 1]。需要注意的是sum[o]的維護,若左端點的右端顏色等於右端點的左端顏色,則sum[o]要減1。

          即若顏色相同,sum[o] = sum[o << 1] + sum[o << 1 | 1] - 1;若顏色不同,sum[o] = sum[o << 1] + sum[o << 1 | 1]

          另一個需要思考的地方便是如何計算答案,由於在樹上剖鏈時剖出的鏈是不連續的,所以不能單純進行累加。

          這時就要用到 lca 了,可以求出兩個點的lca,分別對兩段進行累加,這樣答案就可以計算了。

          由於在樹剖時是從深度大的往深度小的剖,所以線上段樹中較右的節點會先別訪問到,所以可以記錄一個last,表示上一次剖到的左端點顏色是last,這樣就可以將答案累加。

          有很多細節需要注意,細節可以看程式碼。

 

程式碼如下:

  1 #include<cstdio>
  2 #include<vector>
  3 #include<algorithm>
  4 #include<cstring>
  5 #define lc o << 1
  6 #define rc o << 1 | 1
  7 using namespace std;
  8 
  9 const int maxn = 1e5 + 5;
 10 int n, m, col[maxn], bj[maxn * 4], sum[maxn * 4], lf[maxn * 4], rt[maxn * 4], ans, last;
 11 int dep[maxn], fa[maxn], size[maxn], heavy[maxn], seq[maxn], dfn[maxn], top[maxn], cnt;
 12 char s[5];
 13 vector <int> ve[maxn];
 14 
 15 int read(void) {
 16     char c; while (c = getchar(), c < '0' || c >'9'); int x = c - '0';
 17     while (c = getchar(), c >= '0' && c <= '9') x = x * 10 + c - '0'; return x;
 18 }
 19 
 20 void dfs1(int u, int pre) {
 21     dep[u] = dep[pre] + 1;
 22     fa[u] = pre; size[u] = 1;
 23       for (int i = 0; i < ve[u].size(); ++ i) {
 24           int v = ve[u][i];
 25             if (v == pre) continue;
 26           dfs1(v, u);
 27           size[u] += size[v];
 28             if (size[v] > size[heavy[u]]) heavy[u] = v;
 29       }
 30 }
 31 
 32 void dfs2(int u, int cur) {
 33     dfn[u] = ++ cnt; seq[cnt] = u;
 34     top[u] = cur;
 35       if (!heavy[u]) return;
 36     dfs2(heavy[u], cur);
 37       for (int i = 0; i < ve[u].size(); ++ i) {
 38           int v = ve[u][i];
 39             if (v == fa[u] || v == heavy[u]) continue;
 40           dfs2(v, v);
 41       }
 42 }
 43 
 44 void maintain(int o) { //維護每段的資訊 
 45     lf[o] = lf[lc]; rt[o] = rt[rc];
 46     if (rt[lc] == lf[rc]) sum[o] = sum[lc] + sum[rc] - 1;
 47       else sum[o] = sum[lc] + sum[rc];
 48 }
 49 
 50 void pushdown(int o) { //標記下放 
 51     sum[lc] = sum[rc] = 1;
 52     lf[lc] = lf[rc] = rt[lc] = rt[rc] = bj[o];
 53     bj[lc] = bj[rc] = bj[o]; bj[o] = -1; 
 54 }
 55 
 56 void build(int o, int l, int r) { //建樹 
 57     if (l == r) {
 58       lf[o] = col[seq[l]];
 59       rt[o] = col[seq[l]];
 60       sum[o] = 1; 
 61       return;
 62     }
 63     int mid = l + r >> 1;
 64     build(lc, l, mid); build(rc, mid + 1, r);
 65     maintain(o);
 66 }
 67 
 68 void modify(int o, int l, int r, int ql, int qr, int c) { //區間修改 
 69     if (ql <= l && qr >= r) {
 70       lf[o] = rt[o] = c; 
 71       sum[o] = 1; bj[o] = c;
 72       return;
 73     }
 74     int mid = l + r >> 1;
 75     if (bj[o] != -1) pushdown(o);
 76     if (ql <= mid) modify(lc, l, mid, ql, qr, c);
 77     if (qr > mid) modify(rc, mid + 1, r, ql, qr, c);
 78     maintain(o);
 79 }
 80 
 81 void query(int o, int l, int r, int ql, int qr) {
 82     if (ql <= l && qr >= r) {
 83       if (rt[o] == last) ans += sum[o] - 1; //如果右端的顏色和上一個左端相同,就-1 
 84         else ans += sum[o];  
 85       last = lf[o]; //更新last表示的左端點 
 86       return;
 87     }
 88     int mid = l + r >> 1;
 89     if (bj[o] != -1) pushdown(o);
 90     if (qr > mid) query(rc, mid + 1, r, ql, qr); //由於是從右向左更新答案,所以線段樹上詢問時也要優先向右詢問! 
 91     if (ql <= mid) query(lc, l, mid, ql, qr);
 92 }
 93 
 94 void chain_modify(int x, int y, int c) { //樹上修改 
 95     int fax = top[x], fay = top[y];
 96       while (fax != fay) {
 97           if (dep[fax] < dep[fay]) {
 98               swap(fax, fay);
 99               swap(x, y);
100           }
101         modify(1, 1, n, dfn[fax], dfn[x], c);
102         x = fa[fax];
103         fax = top[x];
104       }
105       if (dep[x] > dep[y]) swap(x, y);
106     modify(1, 1, n, dfn[x], dfn[y], c);
107 }
108  
109 void chain_query(int x, int y) { //樹上詢問 
110     int fax = top[x], fay = top[y];
111       while (fax != fay) {
112           if (dep[fax] < dep[fay]) {
113               swap(fax, fay);
114               swap(x, y);
115           }
116         query(1, 1, n, dfn[fax], dfn[x]);
117         x = fa[fax];
118         fax = top[x];
119       }
120       if (dep[x] > dep[y]) swap(x, y);
121     query(1, 1, n, dfn[x], dfn[y]);
122 }
123 
124 int getlca(int x, int y) { //求lca 
125     int fax = top[x], fay = top[y];
126       while (fax != fay) {
127           if (dep[fax] < dep[fay]) {
128               swap(fax, fay);
129               swap(x, y);
130           }
131         x = fa[fax];
132         fax = top[x];
133       }
134     if (dep[x] > dep[y]) swap(x, y);
135     return x;
136 }
137 
138 int main() {
139     n = read(); m = read();
140       for (int i = 1; i <= n; ++ i) col[i] = read();
141       for (int i = 1; i < n; ++ i) {
142           int x = read(), y = read();
143           ve[x].push_back(y);
144           ve[y].push_back(x);
145       }
146     dfs1(1, 0);
147     dfs2(1, 1);
148     build(1, 1, n);
149     memset(bj, -1, sizeof(bj)); //顏色可以為0!所以初始標記是-1 
150       while (m --) {
151           scanf("%s", s + 1); 
152             if (s[1] == 'C') {
153                   int x = read(), y = read(), c = read();
154                   int lca = getlca(x, y);
155                   chain_modify(x, lca, c); chain_modify(lca, y, c);
156             }
157           else { //要求兩次答案,並累加答案 
158               int x = read(), y = read(); ans = 0; last = -1;
159               int lca = getlca(x, y);
160               chain_query(x, lca);
161               last = -1; 
162               chain_query(lca, y);
163               printf("%d\n", ans - 1); //這裡ans必須-1,因為lca處的顏色必定相同 
164           }
165       }
166     return 0;
167 }