(树链剖分+区间合并)HYSBZ - 2243 染色
题意:
两个操作:
1、把一条树链上的所有点权值变为w。
2、查询一条树链上有多少个颜色段
分析:
一看就是区间合并。做这道题首先需要一定的线段树区间合并基础,
不过这题的合并部分在线段树区间合并的题目里已经算是非常简单的了。
线段树部分没有难度。
那么难点在于,在往LCA上走的时候,我们如何进行区间合并。
本来我想着,在沿重链向上跳的时候顺便判断相邻区间的端点颜色并合并,但两个端点交替往上跳,需要分别记住两侧上一段的端点颜色,写起来很容易出错。
其实可以把两步分开:第一步,沿重链向上跳,把路径上各段区间各自的颜色段数直接累加(先不考虑段与段之间的合并);第二步,再从两个端点分别沿 top 向上跳一遍,比较每条重链顶端 top 与其父节点 fa[top] 的颜色,若相等,说明两段在连接处是同一颜色、被多算了一次,就把答案减 1。
代码:
// HYSBZ 2243 (path color assignment + "number of color segments on a path").
// Heavy-light decomposition maps every tree path to O(log n) contiguous
// position ranges; a segment tree over those positions stores, per range,
// the number of color segments and the colors at both ends so that two
// adjacent ranges can be merged (subtract 1 when the touching ends match).
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <algorithm>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>
using namespace std;

const int maxn = 1000000;

// Forward-star adjacency list: head[u] is the first edge of u, edge[i].next
// the following one, -1 terminates the list.
struct Edge {
    int to, next;
} edge[maxn << 1];

int head[maxn], tot;
int top[maxn];   // head vertex of the heavy chain containing the vertex
int fa[maxn];    // parent in the tree rooted at 1 (0 for the root)
int deep[maxn];  // depth, root has depth 0
int num[maxn];   // subtree size
int p[maxn];     // vertex -> segment-tree position
int fp[maxn];    // segment-tree position -> vertex
int son[maxn];   // heavy child, -1 when the vertex is a leaf
int pos;         // next free segment-tree position

int val[maxn];   // initial color of vertex i (1-based)

// Resets the adjacency list and decomposition state for a new test case.
void init() {
    tot = 0;
    memset(head, -1, sizeof head);
    pos = 0;
    memset(son, -1, sizeof son);
}

// Adds the directed edge u -> v (callers add both directions).
void addedge(int u, int v) {
    edge[tot].to = v;
    edge[tot].next = head[u];
    head[u] = tot++;
}

// First DFS: records depth, parent and subtree size, and picks the heavy
// child (the child with the largest subtree) of every vertex.
void dfs1(int u, int pre, int d) {
    deep[u] = d;
    fa[u] = pre;
    num[u] = 1;
    for (int i = head[u]; i != -1; i = edge[i].next) {
        int v = edge[i].to;
        if (v != pre) {
            dfs1(v, u, d + 1);
            num[u] += num[v];
            if (son[u] == -1 || num[v] > num[son[u]]) son[u] = v;
        }
    }
}

// Second DFS: visits the heavy child first so every heavy chain occupies a
// contiguous run of positions; sp is the head of the chain u belongs to.
void getpos(int u, int sp) {
    top[u] = sp;
    p[u] = pos++;
    fp[p[u]] = u;
    if (son[u] == -1) return;
    getpos(son[u], sp);
    for (int i = head[u]; i != -1; i = edge[i].next) {
        int v = edge[i].to;
        if (v != son[u] && v != fa[u]) getpos(v, v);  // light child starts a new chain
    }
}

struct Node {
    int left, right;  // covered position range [left, right]
    int cnt;          // number of color segments inside the range
    int lcol, rcol;   // colors at positions left and right
    int lazy;         // pending "paint whole range" color, -1 if none
                      // (input colors are assumed to never be -1)
} node[maxn << 2];

// Builds an empty tree over positions [left, right]; leaves are filled later.
void build(int n, int left, int right) {
    node[n].left = left;
    node[n].right = right;
    node[n].cnt = node[n].lcol = node[n].rcol = 0;
    node[n].lazy = -1;
    if (left == right) return;
    int mid = (left + right) >> 1;
    build(n << 1, left, mid);
    build(n << 1 | 1, mid + 1, right);
}

// Recomputes node n from its children; the two halves share one segment
// when the left child's right color equals the right child's left color.
void push_up(int n) {
    node[n].lcol = node[n << 1].lcol;
    node[n].rcol = node[n << 1 | 1].rcol;
    node[n].cnt = node[n << 1].cnt + node[n << 1 | 1].cnt;
    if (node[n << 1].rcol == node[n << 1 | 1].lcol) node[n].cnt--;
}

// Propagates a pending paint tag to both children.
void push_down(int n) {
    if (node[n].lazy != -1) {
        node[n << 1].cnt = 1;
        node[n << 1].lcol = node[n << 1].rcol = node[n].lazy;
        node[n << 1].lazy = node[n].lazy;
        node[n << 1 | 1].cnt = 1;
        node[n << 1 | 1].lcol = node[n << 1 | 1].rcol = node[n].lazy;
        node[n << 1 | 1].lazy = node[n].lazy;
        node[n].lazy = -1;
    }
}

// Paints positions [left, right] with color c.
void update(int n, int left, int right, int c) {
    if (left <= node[n].left && node[n].right <= right) {
        node[n].cnt = 1;
        node[n].lcol = node[n].rcol = c;
        node[n].lazy = c;
        return;
    }
    push_down(n);
    int mid = (node[n].left + node[n].right) >> 1;
    if (mid >= left) update(n << 1, left, right, c);
    if (mid < right) update(n << 1 | 1, left, right, c);
    push_up(n);
}

// Returns the number of color segments on positions [left, right].
int query_cnt(int n, int left, int right) {
    if (left <= node[n].left && node[n].right <= right) {
        return node[n].cnt;
    }
    push_down(n);
    int mid = (node[n].left + node[n].right) >> 1;
    if (mid >= right) return query_cnt(n << 1, left, right);
    if (mid < left) return query_cnt(n << 1 | 1, left, right);
    // The query straddles mid, so it contains both positions mid and mid+1,
    // whose colors are the children's touching end colors.
    int cnt = query_cnt(n << 1, left, right) + query_cnt(n << 1 | 1, left, right);
    if (node[n << 1].rcol == node[n << 1 | 1].lcol) cnt--;
    return cnt;
}

// Returns the color stored at a single position.
int query_col(int n, int position) {
    if (node[n].left == node[n].right) {
        return node[n].lcol;
    }
    push_down(n);
    int mid = (node[n].left + node[n].right) >> 1;
    if (position <= mid) return query_col(n << 1, position);
    return query_col(n << 1 | 1, position);
}

// Counts color segments on the tree path x - y in two phases:
//  1. sum the segment counts of every chain piece on the path, ignoring the
//     joints between pieces;
//  2. walk each endpoint up to the LCA's chain again and, for every joint
//     top[w] -- fa[top[w]], subtract 1 when both sides have the same color.
int findCnt(int x, int y) {
    int u = x, v = y;
    int total = 0;
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        total += query_cnt(1, p[top[x]], p[x]);
        x = fa[top[x]];
    }
    if (deep[x] > deep[y]) swap(x, y);
    total += query_cnt(1, p[x], p[y]);  // x is now the LCA
    while (top[u] != top[x]) {
        if (query_col(1, p[top[u]]) == query_col(1, p[fa[top[u]]])) total--;
        u = fa[top[u]];
    }
    while (top[v] != top[x]) {
        if (query_col(1, p[top[v]]) == query_col(1, p[fa[top[v]]])) total--;
        v = fa[top[v]];
    }
    return total;
}

// Paints every vertex on the tree path x - y with color c.
void Change(int x, int y, int c) {
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        update(1, p[top[x]], p[x], c);
        x = fa[top[x]];
    }
    if (deep[x] > deep[y]) swap(x, y);
    update(1, p[x], p[y], c);
}

// Input per test case: "n m", n initial colors, n-1 edges, then m operations
// "Q a b" (query) or "C a b c" (paint). The operation count is the m read
// on the first line; there is no separate count before the operations.
int main() {
    int n, q;
    while (scanf("%d%d", &n, &q) == 2) {
        init();
        for (int i = 1; i <= n; i++) {
            if (scanf("%d", &val[i]) != 1) return 0;
        }
        for (int i = 0; i < n - 1; i++) {
            int a, b;
            if (scanf("%d%d", &a, &b) != 2) return 0;
            addedge(a, b);
            addedge(b, a);
        }

        dfs1(1, 0, 0);
        getpos(1, 1);
        build(1, 0, pos - 1);
        for (int i = 1; i <= n; i++) {
            update(1, p[i], p[i], val[i]);
        }

        char op[10];
        while (q-- > 0) {
            int a, b;
            if (scanf("%9s%d%d", op, &a, &b) != 3) return 0;
            if (op[0] == 'Q') {
                printf("%d\n", findCnt(a, b));
            } else {
                int color;
                if (scanf("%d", &color) != 1) return 0;
                Change(a, b, color);
            }
        }
    }
    return 0;
}