洛谷-P2486 染色
染色
树链剖分
考虑如果在数列上的话,就是用线段树处理这个问题
线段树记录答案,并且处理区间合并的问题:如果区间合并的地方颜色相同,则加和后的答案要减一
因此维护所有线段树区间两端的颜色
染色的过程可以加入 \(lazytag\)
然后再在树上跑一个树链剖分
时间复杂度为 \(O(n\log^2 n)\)
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
// Capacity of every per-vertex array; the segment-tree arrays are four
// times this size.
const int maxn = 100000 + 10;

// Per-vertex data written by dfs1: parent, heavy child (-1 when the vertex
// has no child), subtree size and depth.
int fa[maxn];
int hson[maxn];
int siz[maxn];
int dep[maxn];

// Per-vertex data written by dfs2: top vertex of the chain the vertex lies
// on, and the 1-based position of the vertex in dfs2's visiting order.
// rnk maps a position back to its vertex; tp counts positions handed out.
int top[maxn];
int dfn[maxn];
int rnk[maxn];
int tp = 0;

// Vertex colours as read by main.
int col[maxn];

// Segment-tree node data, indexed by node number (root is 1, children of
// node k are 2k and 2k+1):
//   lcol / rcol - colour at the left / right end of the node's range
//   tag         - pending colour to push to the children, 0 meaning "none"
//   tr          - number of maximal same-colour runs inside the node's range
int lcol[maxn << 2];
int rcol[maxn << 2];
int tag[maxn << 2];
int tr[maxn << 2];

// Adjacency lists of the tree; each edge is stored in both directions.
std::vector<int> gra[maxn];
// First decomposition pass: records parent, depth and subtree size of every
// vertex reachable from `now`, and picks each vertex's heavy child.
//   now - vertex being visited
//   pre - vertex we arrived from (stored as the parent of `now`)
//   d   - depth to assign to `now`
void dfs1(int now, int pre, int d)
{
    fa[now] = pre;
    dep[now] = d;
    siz[now] = 1;
    hson[now] = -1;  // stays -1 when `now` has no child

    for(int child : gra[now])
    {
        if(child != pre)
        {
            dfs1(child, now, d + 1);
            siz[now] += siz[child];

            // Keep the child with the strictly largest subtree seen so far.
            const int heavy = hson[now];
            if(heavy == -1 || siz[heavy] < siz[child])
                hson[now] = child;
        }
    }
}
// Second decomposition pass: numbers the vertices and assigns chain tops.
// The heavy child is visited first, so each chain occupies consecutive
// positions in dfn.
//   now - vertex being visited
//   t   - top vertex of the chain `now` belongs to
void dfs2(int now, int t)
{
    dfn[now] = ++tp;
    rnk[tp] = now;
    top[now] = t;

    const int heavy = hson[now];
    if(heavy == -1) return;  // no child: nothing left to visit

    // The heavy child continues the current chain.
    dfs2(heavy, t);

    // Every other child starts a new chain with itself as the top.
    for(int child : gra[now])
    {
        if(child != heavy && child != fa[now])
            dfs2(child, child);
    }
}
// Recomputes the run count of node `now` from its two children. Runs that
// touch at the boundary with the same colour merge into one. lcol/rcol of
// `now` are not touched here.
inline void push_up(int now)
{
    const int left = now << 1;
    const int right = left | 1;
    const int merged = (rcol[left] == lcol[right]) ? 1 : 0;
    tr[now] = tr[left] + tr[right] - merged;
}
// Builds the segment tree over positions [l, r], where position i holds the
// colour of vertex rnk[i].
//   now  - index of the node being built
//   l, r - inclusive range of positions covered by `now`
void build(int now, int l, int r)
{
    // End colours come straight from the underlying sequence.
    lcol[now] = col[rnk[l]];
    rcol[now] = col[rnk[r]];

    if(l == r)
    {
        // A single position is exactly one run.
        tr[now] = 1;
        return;
    }

    const int mid = (l + r) >> 1;
    const int left = now << 1;
    const int right = left | 1;
    build(left, l, mid);
    build(right, mid + 1, r);
    push_up(now);
}
// Pushes the pending colour of node `now` to both children and clears it.
// A tag of 0 means nothing is pending, so colour 0 cannot be propagated
// this way.
inline void push_down(int now)
{
    const int pending = tag[now];
    if(pending == 0) return;

    const int kids[2] = {now << 1, now << 1 | 1};
    for(int kid : kids)
    {
        // The child becomes one run of the pending colour.
        tr[kid] = 1;
        lcol[kid] = pending;
        rcol[kid] = pending;
        tag[kid] = pending;
    }
    tag[now] = 0;
}
// Paints every position in [L, R] with colour `val`.
//   now  - index of the current node
//   l, r - inclusive range covered by `now`
//   L, R - inclusive range to paint
//   val  - new colour
void update(int now, int l, int r, int L, int R, int val)
{
    // push_up does not maintain end colours, so they are refreshed here on
    // the way down whenever the painted range reaches an end of this node.
    if(L <= l) lcol[now] = val;
    if(R >= r) rcol[now] = val;

    const bool covered = (L <= l) && (r <= R);
    if(covered)
    {
        // The whole node becomes one run; defer the children via the tag.
        tr[now] = 1;
        tag[now] = val;
        return;
    }

    push_down(now);
    const int mid = (l + r) >> 1;
    const int left = now << 1;
    const int right = left | 1;
    if(L <= mid) update(left, l, mid, L, R, val);
    if(mid < R) update(right, mid + 1, r, L, R, val);
    push_up(now);
}
// Returns the number of maximal same-colour runs among positions [L, R].
//   now  - index of the current node
//   l, r - inclusive range covered by `now`
//   L, R - inclusive query range
// Returns 0 when the query range is empty or does not intersect [l, r], so
// every path through the function yields a value.
int query_num(int now, int l, int r, int L, int R)
{
    // Empty or disjoint query: nothing to count. Without this guard the
    // function could fall off its end without returning.
    if(L > R || R < l || r < L) return 0;

    if(L <= l && r <= R)
        return tr[now];

    const int mid = (l + r) >> 1;
    const int lson = now << 1;
    const int rson = lson | 1;
    push_down(now);

    // Query lies entirely in one half.
    if(R <= mid) return query_num(lson, l, mid, L, R);
    if(L > mid) return query_num(rson, mid + 1, r, L, R);

    // Query straddles the midpoint: positions mid and mid + 1 are both
    // inside it, so equal colours there merge two runs into one.
    int ans = query_num(lson, l, mid, L, R);
    ans += query_num(rson, mid + 1, r, L, R);
    if(rcol[lson] == lcol[rson]) ans--;
    return ans;
}
// Returns the colour stored at position `way`.
//   now  - index of the current node
//   l, r - inclusive range covered by `now`
//   way  - position to look up
int query_col(int now, int l, int r, int way)
{
    if(l == r && way == l) return lcol[now];

    const int mid = (l + r) >> 1;
    push_down(now);

    const int found = (way <= mid)
        ? query_col(now << 1, l, mid, way)
        : query_col(now << 1 | 1, mid + 1, r, way);

    // Recompute this node's run count from the children before returning.
    push_up(now);
    return found;
}
// Runs both decomposition passes from vertex 1 and builds the segment tree
// over positions 1..n.
//   n - number of vertices
void init(int n)
{
    const int root = 1;
    dfs1(root, root, 1);  // the root is recorded as its own parent, depth 1
    dfs2(root, root);     // the root is the top of its own chain
    build(1, 1, n);
}
// Returns the number of maximal same-colour runs on the tree path u..v.
//   u, v - path endpoints
//   n    - number of vertices (size of the segment tree)
int query_p(int u, int v, int n)
{
    int total = 0;

    // Climb one chain at a time, always from the endpoint whose chain top
    // is deeper, until both endpoints lie on the same chain.
    for(; top[u] != top[v]; u = fa[top[u]])
    {
        if(dep[top[u]] < dep[top[v]]) std::swap(u, v);

        const int head = top[u];
        total += query_num(1, 1, n, dfn[head], dfn[u]);

        // The chain top and its parent are adjacent on the path; if their
        // colours match, the run continues across the chain boundary.
        if(head != 1 &&
           query_col(1, 1, n, dfn[head]) == query_col(1, 1, n, dfn[fa[head]]))
            --total;
    }

    // Same chain: the remaining stretch is one contiguous range.
    const int lo = std::min(dfn[u], dfn[v]);
    const int hi = std::max(dfn[u], dfn[v]);
    total += query_num(1, 1, n, lo, hi);
    return total;
}
// Paints every vertex on the tree path u..v with colour x.
//   u, v - path endpoints
//   x    - new colour
//   n    - number of vertices (size of the segment tree)
void update_p(int u, int v, int x, int n)
{
    // Paint chain by chain, always from the endpoint whose chain top is
    // deeper, until both endpoints lie on the same chain.
    for(; top[u] != top[v]; u = fa[top[u]])
    {
        if(dep[top[u]] < dep[top[v]]) std::swap(u, v);
        update(1, 1, n, dfn[top[u]], dfn[u], x);
    }

    // Same chain: paint the remaining contiguous range.
    const int lo = std::min(dfn[u], dfn[v]);
    const int hi = std::max(dfn[u], dfn[v]);
    update(1, 1, n, lo, hi, x);
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for(int i=1; i<=n; i++) scanf("%d", &col[i]);
for(int i=1; i<n; i++)
{
int a, b;
scanf("%d%d", &a, &b);
gra[a].push_back(b);
gra[b].push_back(a);
}
init(n);
char op[10];
while(m--)
{
scanf("%s", op);
if(op[0] == 'Q')
{
int a, b;
scanf("%d%d", &a, &b);
printf("%d\n", query_p(a, b, n));
}
else
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
update_p(a, b, c, n);
}
}
return 0;
}