【ybt金牌导航5-1-3】【luogu P2486】树上染色 / 染色
树上染色 / 染色
题目链接:ybt金牌导航5-1-3 / luogu P2486
题目大意
给你一个树,然后点有颜色,要你维护两个操作。
把一条路径上的点的颜色改成某个值,查询一条路径的颜色段数量。
颜色段是指最长的颜色相同的段。
思路
看到树上路径,自然想到树链剖分。
然后你发现你就需要线段树维护区间最左最右的颜色,这个很好维护。
然后你也可以维护出区间颜色段的。
然后记得树链剖分的时候你将链合并时它们也可能缝合,那你每次要把这条链两端颜色找出来,然后拿去比较。
然后大概就是这样。
代码
#include<cstdio>
#include<algorithm>
using namespace std;
struct node {
int to, nxt;
}e[200001];
int n, m, a[100001], x, y, z;
int le[100001], KK, lcol, rcol;
int fa[100001], son[100001], sz[100001], deg[100001];
int pl[100001], dfn[100001], top[100001], tmp;
char op;
struct Tree {
int lc, rc, num, lazy;
}tr[400001];
void add(int x, int y) {
e[++KK] = (node){y, le[x]}; le[x] = KK;
e[++KK] = (node){x, le[y]}; le[y] = KK;
}
void dfs1(int now, int father) {//树链剖分
fa[now] = father;
deg[now] = deg[father] + 1;
sz[now] = 1;
int maxn = 0;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father) {
dfs1(e[i].to, now);
sz[now] += sz[e[i].to];
if (sz[e[i].to] > maxn) {
maxn = sz[e[i].to];
son[now] = e[i].to;
}
}
}
void dfs2(int now, int father) {
if (son[now]) {
pl[son[now]] = ++tmp;
top[son[now]] = top[now];
dfn[tmp] = son[now];
dfs2(son[now], now);
}
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father && e[i].to != son[now]) {
pl[e[i].to] = ++tmp;
top[e[i].to] = e[i].to;
dfn[tmp] = e[i].to;
dfs2(e[i].to, now);
}
}
//线段树操作
void up(int now) {
tr[now].lc = tr[now << 1].lc;
tr[now].rc = tr[now << 1 | 1].rc;
tr[now].num = tr[now << 1].num + tr[now << 1 | 1].num;
if (tr[now << 1].rc == tr[now << 1 | 1].lc)//接合的地方下相同颜色,两颜色段合并,所以个数少了一个
tr[now].num--;//下面的很多地方减一都是处理这个
}
void down(int now, int l, int r) {
if (tr[now].lazy != -1 && l != r) {
tr[now << 1].num = tr[now << 1 | 1].num = 1;
tr[now << 1].lc = tr[now << 1].rc = tr[now].lazy;
tr[now << 1 | 1].lc = tr[now << 1 | 1].rc = tr[now].lazy;
tr[now << 1].lazy = tr[now << 1 | 1].lazy = tr[now].lazy;
tr[now].lazy = -1;
}
}
void build(int now, int l, int r) {
if (l == r) {
tr[now] = (Tree){a[dfn[l]], a[dfn[l]], 1, -1};
return ;
}
tr[now].lazy = -1;
int mid = (l + r) >> 1;
build(now << 1, l, mid);
build(now << 1 | 1, mid + 1, r);
up(now);
}
void change(int now, int l, int r, int L, int R, int num) {
if (L <= l && r <= R) {
tr[now].lc = tr[now].rc = num;
tr[now].num = 1;
tr[now].lazy = num;
return ;
}
down(now, l, r);
int mid = (l + r) >> 1;
if (L <= mid) change(now << 1, l, mid, L, R, num);
if (mid < R) change(now << 1 | 1, mid + 1, r, L, R, num);
up(now);
}
int query(int now, int l, int r, int L, int R, int &l_col, int &r_col) {
if (L <= l && r <= R) {
if (l == L) l_col = tr[now].lc;//找到这段区间最左最右的颜色
if (r == R) r_col = tr[now].rc;
return tr[now].num;
}
down(now, l, r);
int mid = (l + r) >> 1, re = 0;
if (L <= mid) re += query(now << 1, l, mid, L, R, l_col, r_col);
if (mid < R) re += query(now << 1 | 1, mid + 1, r, L, R, l_col, r_col);
if (L <= mid && mid < R) {
if (tr[now << 1].rc == tr[now << 1 | 1].lc) re--;//相同颜色缝合
}
return re;
}
void color(int x, int y, int z) {
int X = top[x], Y = top[y];
while (X != Y) {
if (deg[X] < deg[Y]) {
swap(X, Y);
swap(x, y);
}
change(1, 1, tmp, pl[X], pl[x], z);
x = fa[X];
X = top[x];
}
if (deg[x] > deg[y]) swap(x, y);
change(1, 1, tmp, pl[x], pl[y], z);
}
int ask_num(int x, int y) {
int X = top[x], Y = top[y], re = 0, tmp1, tmp2;
lcol = rcol = -2;
while (X != Y) {
if (deg[X] < deg[Y]) {
swap(X, Y);
swap(x, y);
swap(lcol, rcol);
}
tmp1 = tmp2 = -4;
re += query(1, 1, tmp, pl[X], pl[x], tmp1, tmp2);
if (tmp2 == lcol) re--;//跟你前面走的链缝合
lcol = tmp1;
x = fa[X];
X = top[x];
}
if (deg[x] > deg[y]) swap(x, y), swap(lcol, rcol);
tmp1 = tmp2 = -4;
re += query(1, 1, tmp, pl[x], pl[y], tmp1, tmp2);
if (tmp1 == lcol) re--;//对接处有两个可以缝合的地方
if (tmp2 == rcol) re--;
return re;
}
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i < n; i++) {
scanf("%d %d", &x, &y);
add(x, y);
}
dfs1(1, 0);
pl[1] = ++tmp; top[1] = 1; dfn[tmp] = 1;
dfs2(1, 0);
build(1, 1, n);
while (m--) {
op = getchar();
while (op != 'C' && op != 'Q') op = getchar();
if (op == 'C') {
scanf("%d %d %d", &x, &y, &z);
color(x, y, z);
continue;
}
if (op == 'Q') {
scanf("%d %d", &x, &y);
printf("%d\n", ask_num(x, y));
continue;
}
}
return 0;
}