BZOJ 2243 染色(树链剖分好题)
2243: [SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 7971 Solved: 2990
[Submit][Status][Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
题目链接:BZOJ 2243
做了几道普通的树链剖分维护边权、点权,查询路径的题目,感觉并没有什么特点,然而这题比较有意思,求路径上连续颜色有几段,显然用线段树的话只要维护当前区间最左和最右的颜色,左右子区间即可推出父区间的答案:左边段数+右边段数-(左区间右端点颜色==右区间左端点颜色)。然后统计的时候也要利用这个思想——线段树的query与树链剖分中记录u与v上升区间段数的同时也与u、v最后上升的区间最左端点颜色比较得到答案。
代码:
#include <bits/stdc++.h> using namespace std; #define INF 0x3f3f3f3f #define LC(x) (x<<1) #define RC(x) ((x<<1)+1) #define MID(x,y) ((x+y)>>1) #define fin(name) freopen(name,"r",stdin) #define fout(name) freopen(name,"w",stdout) #define CLR(arr,val) memset(arr,val,sizeof(arr)) #define FAST_IO ios::sync_with_stdio(false);cin.tie(0); typedef pair<int, int> pii; typedef long long LL; const double PI = acos(-1.0); const int N = 100010; struct seg { int l, mid, r; int lc, rc; int s, tag; }; struct edge { int to, nxt; edge() {} edge(int _to, int _nxt): to(_to), nxt(_nxt) {} }; edge E[N << 1]; seg T[N << 2]; int head[N], tot; int sz[N], fa[N], son[N], top[N], dep[N], idx[N], ts; int arr[N]; int Rc, Lc; void init() { CLR(head, -1); tot = 0; ts = 0; } void add(int s, int t) { E[tot] = edge(t, head[s]); head[s] = tot++; } void dfs1(int u, int f, int d) { sz[u] = 1; fa[u] = f; son[u] = -1; dep[u] = d; for (int i = head[u]; ~i; i = E[i].nxt) { int v = E[i].to; if (v != f) { dfs1(v, u, d + 1); sz[u] += sz[v]; if (son[u] == -1 || sz[son[u]] < sz[v]) son[u] = v; } } } void dfs2(int u, int tp) { idx[u] = ++ts; top[u] = tp; if (~son[u]) dfs2(son[u], tp); for (int i = head[u]; ~i; i = E[i].nxt) { int v = E[i].to; if (v != fa[u] && v != son[u]) dfs2(v, v); } } void pushup(int k) { T[k].s = T[LC(k)].s + T[RC(k)].s - (T[LC(k)].rc == T[RC(k)].lc); T[k].lc = T[LC(k)].lc; T[k].rc = T[RC(k)].rc; } void pushdown(int k) { if (T[k].tag == -1) return ; T[LC(k)].tag = T[RC(k)].tag = T[k].tag; T[LC(k)].lc = T[LC(k)].rc = T[k].tag; T[RC(k)].lc = T[RC(k)].rc = T[k].tag; T[LC(k)].s = T[RC(k)].s = 1; T[k].tag = -1; } void build(int k, int l, int r) { T[k].l = l; T[k].r = r; T[k].mid = MID(l, r); T[k].lc = T[k].rc = 0; T[k].tag = -1; T[k].s = 0; if (l == r) return ; build(LC(k), l, T[k].mid); build(RC(k), T[k].mid + 1, r); } void update(int k, int l, int r, int c) { if (l <= T[k].l && T[k].r <= r) { T[k].tag = c; T[k].lc = T[k].rc = c; T[k].s = 1; } else { pushdown(k); if (r <= T[k].mid) update(LC(k), l, r, c); else if (l > T[k].mid) update(RC(k), l, r, c); else { update(LC(k), l, T[k].mid, c); update(RC(k), T[k].mid + 1, r, c); } pushup(k); } } int query(int k, int l, int r, int L, int R) { if (L == T[k].l) Lc = T[k].lc; if (R == T[k].r) Rc = T[k].rc; if (l <= T[k].l && T[k].r <= r) return T[k].s; else { pushdown(k); if (r <= T[k].mid) return query(LC(k), l, r, L, R); else if (l > T[k].mid) return query(RC(k), l, r, L, R); else return query(LC(k), l, T[k].mid, L, R) + query(RC(k), T[k].mid + 1, r, L, R) - (T[LC(k)].rc == T[RC(k)].lc); } } int Find(int u, int v) { int ret = 0; int tu = top[u], tv = top[v]; int last_u = -1, last_v = -1; while (tu != tv) { if (dep[tu] < dep[tv]) { swap(tu, tv); swap(u, v); swap(last_u, last_v); } ret += query(1, idx[tu], idx[u], idx[tu], idx[u]); if (Rc == last_u) --ret; last_u = Lc; u = fa[tu]; tu = top[u]; } if (dep[u] > dep[v]) { swap(u, v); swap(last_u, last_v); } ret += query(1, idx[u], idx[v], idx[u], idx[v]); if (Lc == last_u) --ret; if (Rc == last_v) --ret; return ret; } void solve(int u, int v, int c) { int tu = top[u], tv = top[v]; while (tu != tv) { if (dep[tu] < dep[tv]) { swap(tu, tv); swap(u, v); } update(1, idx[tu], idx[u], c); u = fa[tu]; tu = top[u]; } if (dep[u] > dep[v]) swap(u, v); update(1, idx[u], idx[v], c); } int main(void) { int n, m, a, b, c, i; char ops[10]; while (~scanf("%d%d", &n, &m)) { init(); for (i = 1; i <= n; ++i) scanf("%d", &arr[i]); for (i = 1; i < n; ++i) { scanf("%d%d", &a, &b); add(a, b); add(b, a); } dfs1(1, 0, 1); dfs2(1, 1); build(1, 1, n); for (i = 1; i <= n; ++i) update(1, idx[i], idx[i], arr[i]); while (m--) { scanf("%s", ops); if (ops[0] == 'Q') { scanf("%d%d", &a, &b); printf("%d\n", Find(a, b)); } else { scanf("%d%d%d", &a, &b, &c); solve(a, b, c); } } } return 0; }