(树链剖分+区间合并)HYSBZ - 2243 染色

题意:

两个操作:

1、把一条树链上的所有点权值变为w。

2、查询一条树链上有多少个颜色段

 

分析:

一看就是区间合并，做这道题首先需要一定的区间合并基础，

不过这题的合并部分在线段树区间合并中已经算是非常简单的了。

线段树部分没有难度。

那么难点在于,在往LCA上走的时候,我们如何进行区间合并。

本来我想着, 在向上走的时候顺便进行区间判断并且合并,但是似乎有问题。

其实，可以将两步分开：先把路径上每条重链查询到的颜色段数直接相加(暂不考虑链与链之间的合并)，然后再沿着 top 向上跳一遍，判断每条链的链顶与其父节点的颜色是否相等，相等就把段数减一。

 

代码:

  1 #include <math.h>
  2 #include <stdio.h>
  3 #include <stdlib.h>
  4 #include <string.h>
  5 #include <time.h>
  6 #include <algorithm>
  7 #include <iostream>
  8 #include <map>
  9 #include <queue>
 10 #include <set>
 11 #include <string>
 12 #include <vector>
 13 using namespace std;
 14 
 15 const int maxn = 1000000;
 16 const int inf = 0x3f3f3f3f;
 17 
 18 struct Edge {
 19     int to, next;
 20 } edge[maxn << 1];
 21 
 22 int head[maxn], tot;
 23 int top[maxn];
 24 int fa[maxn];
 25 int deep[maxn];
 26 int num[maxn];
 27 int p[maxn];
 28 int fp[maxn];
 29 int son[maxn];
 30 int pos;
 31 
 32 int val[maxn];
 33 
 34 void init() {
 35     tot = 0;
 36     memset(head, -1, sizeof head);
 37     pos = 0;
 38     memset(son, -1, sizeof son);
 39 }
 40 
 41 void addedge(int u, int v) {
 42     edge[tot].to = v;
 43     edge[tot].next = head[u];
 44     head[u] = tot++;
 45 }
 46 void dfs1(int u, int pre, int d) {
 47     deep[u] = d;
 48     fa[u] = pre;
 49     num[u] = 1;
 50     for (int i = head[u]; i != -1; i = edge[i].next) {
 51         int v = edge[i].to;
 52         if (v != pre) {
 53             dfs1(v, u, d + 1);
 54             num[u] += num[v];
 55             if (son[u] == -1 || num[v] > num[son[u]]) son[u] = v;
 56         }
 57     }
 58 }
 59 
 60 void getpos(int u, int sp) {
 61     top[u] = sp;
 62     p[u] = pos++;
 63     fp[p[u]] = u;
 64     if (son[u] == -1) return;
 65     getpos(son[u], sp);
 66     for (int i = head[u]; i != -1; i = edge[i].next) {
 67         int v = edge[i].to;
 68         if (v != son[u] && v != fa[u]) getpos(v, v);
 69     }
 70 }
 71 
 72 struct Node {
 73     int left, right;
 74     int cnt, lcol, rcol;
 75     int lazy;
 76 } node[maxn << 2];
 77 
 78 void build(int n, int left, int right) {
 79     node[n].left = left;
 80     node[n].right = right;
 81     node[n].cnt = node[n].lcol = node[n].rcol = 0;
 82     node[n].lazy = -1;
 83     if (left == right) return;
 84     int mid = (left + right) >> 1;
 85     build(n << 1, left, mid);
 86     build(n << 1 | 1, mid + 1, right);
 87 }
 88 
 89 void push_up(int n) {
 90     node[n].lcol = node[n << 1].lcol;
 91     node[n].rcol = node[n << 1 | 1].rcol;
 92     node[n].cnt = node[n << 1].cnt + node[n << 1 | 1].cnt;
 93     if (node[n << 1].rcol == node[n << 1 | 1].lcol) node[n].cnt--;
 94 }
 95 
 96 void push_down(int n) {
 97     if (node[n].lazy != -1) {
 98         node[n << 1].cnt = 1;
 99         node[n << 1].lcol = node[n << 1].rcol = node[n].lazy;
100         node[n << 1].lazy = node[n].lazy;
101         node[n << 1 | 1].cnt = 1;
102         node[n << 1 | 1].lcol = node[n << 1 | 1].rcol = node[n].lazy;
103         node[n << 1 | 1].lazy = node[n].lazy;
104         node[n].lazy = -1;
105     }
106 }
107 
108 void update(int n, int left, int right, int val) {
109     if (left <= node[n].left && node[n].right <= right) {
110         node[n].cnt = 1;
111         node[n].lcol = node[n].rcol = val;
112         node[n].lazy = val;
113         return;
114     }
115     push_down(n);
116     int mid = (node[n].left + node[n].right) >> 1;
117     if (mid >= left) update(n << 1, left, right, val);
118     if (mid < right) update(n << 1 | 1, left, right, val);
119     push_up(n);
120 }
121 
122 int query_cnt(int n, int left, int right) {
123     if (left <= node[n].left && node[n].right <= right) {
124         return node[n].cnt;
125     }
126     push_down(n);
127     int mid = (node[n].left + node[n].right) >> 1;
128     if (mid >= right)
129         return query_cnt(n << 1, left, right);
130     else if (mid < left)
131         return query_cnt(n << 1 | 1, left, right);
132     else {
133         int lcnt = query_cnt(n << 1, left, right);
134         int rcnt = query_cnt(n << 1 | 1, left, right);
135         int cnt = lcnt + rcnt;
136         if (node[n << 1].rcol == node[n << 1 | 1].lcol) cnt--;
137         push_up(n);
138         return cnt;
139     }
140 }
141 
142 int query_col(int n, int pos) {
143     if (node[n].left == node[n].right) {
144         return node[n].lcol;
145     }
146     push_down(n);
147     int mid = (node[n].left + node[n].right) >> 1;
148     if (pos <= mid)
149         return query_col(n << 1, pos);
150     else
151         return query_col(n << 1 | 1, pos);
152 }
153 
154 int findCnt(int x, int y) {
155     int u = x, v = y;
156     int tmp = 0;
157     int precol = -1;
158     while (top[x] != top[y]) {
159         if (deep[top[x]] < deep[top[y]]) swap(x, y);
160         tmp += query_cnt(1, p[top[x]], p[x]);
161         x = fa[top[x]];
162     }
163     if (deep[x] > deep[y]) swap(x, y);
164     tmp += query_cnt(1, p[x], p[y]);
165     // if (top[u] == top[v]) return tmp;
166     while (top[u] != top[x]) {
167         int col1 = query_col(1, p[top[u]]);
168         int col2 = query_col(1, p[fa[top[u]]]);
169         if (col1 == col2) tmp--;
170         u = fa[top[u]];
171     }
172     while (top[v] != top[x]) {
173         int col1 = query_col(1, p[top[v]]);
174         int col2 = query_col(1, p[fa[top[v]]]);
175         if (col1 == col2) tmp--;
176         v = fa[top[v]];
177     }
178     return tmp;
179 }
180 
181 void Change(int x, int y, int val) {
182     while (top[x] != top[y]) {
183         if (deep[top[x]] < deep[top[y]]) swap(x, y);
184         update(1, p[top[x]], p[x], val);
185         x = fa[top[x]];
186     }
187     if (deep[x] > deep[y]) swap(x, y);
188     update(1, p[x], p[y], val);
189 }
190 
191 int main() {
192     int t;
193     int n;
194     int q;
195     while (~scanf("%d%d", &n, &q)) {
196         init();
197         for (int i = 1; i <= n; i++) {
198             scanf("%d", &val[i]);
199         }
200         for (int i = 0; i < n - 1; i++) {
201             int u, v;
202             scanf("%d%d", &u, &v);
203             addedge(u, v);
204             addedge(v, u);
205         }
206 
207         dfs1(1, 0, 0);
208         getpos(1, 1);
209         build(1, 0, pos - 1);
210         for (int i = 1; i <= n; i++) {
211             update(1, p[i], p[i], val[i]);
212         }
213         scanf("%d", &q);
214         char op[10];
215         int u, v;
216         while (q--) {
217             scanf("%s%d%d", op, &u, &v);
218             if (op[0] == 'Q') {
219                 printf("%d\n", findCnt(u, v));
220             } else {
221                 int val;
222                 scanf("%d", &val);
223                 Change(u, v, val);
224             }
225         }
226     }
227     return 0;
228 }

 

posted @ 2017-07-17 22:40  tak_fate  阅读(203)  评论(0编辑  收藏  举报