BZOJ 3531: [Sdoi2014]旅行 权值线段树 + 树链剖分
Description
S国有N个城市,编号从1到N。城市间用N-1条双向道路连接,满足
从一个城市出发可以到达其它所有城市。每个城市信仰不同的宗教,如飞天面条神教、隐形独角兽教、绝地教都是常见的信仰。为了方便,我们用不同的正整数代表各种宗教, S国的居民常常旅行。旅行时他们总会走最短路,并且为了避免麻烦,只在信仰和他们相同的城市留宿。当然旅程的终点也是信仰与他相同的城市。S国政府为每个城市标定了不同的旅行评级,旅行者们常会记下途中(包括起点和终点)留宿过的城市的评级总和或最大值。
在S国的历史上常会发生以下几种事件:
”CC x c”:城市x的居民全体改信了c教;
”CW x w”:城市x的评级调整为w;
”QS x y”:一位旅行者从城市x出发,到城市y,并记下了途中留宿过的城市的评级总和;
”QM x y”:一位旅行者从城市x出发,到城市y,并记下了途中留宿过
的城市的评级最大值。
由于年代久远,旅行者记下的数字已经遗失了,但记录开始之前每座城市的信仰与评级,还有事件记录本身是完好的。请根据这些信息,还原旅行者记下的数字。 为了方便,我们认为事件之间的间隔足够长,以致在任意一次旅行中,所有城市的评级和信仰保持不变。
Input
输入的第一行包含整数N,Q依次表示城市数和事件数。
接下来N行,第i+l行两个整数Wi,Ci依次表示记录开始之前,城市i的
评级和信仰。
接下来N-1行每行两个整数x,y表示一条双向道路。
接下来Q行,每行一个操作,格式如上所述。
Output
对每个QS和QM事件,输出一行,表示旅行者记下的数字。
题解:对于每一个宗教分别开一个线段树.
下标为树剖序,权值为树上点权.
由于宗教数目是 $O(n)$ 的,动态开点即可.
#include <bits/stdc++.h> #define setIO(s) freopen(s".in","r",stdin) #define ll long long #define inf 100000000000 #define maxn 500000 #define N 200003 using namespace std; namespace Seg { #define lson t[x].l #define rson t[x].r int n, tot; struct Node { int l, r; ll sumv, maxv; }t[maxn << 2]; void pushup(int x) { t[x].sumv = t[lson].sumv + t[rson].sumv; t[x].maxv = max(t[lson].maxv, t[rson].maxv); } void ins(int &x, int l, int r, int p, ll v) { if(!x) x = ++ tot; if(l == r) { t[x].sumv = t[x].maxv = v; return ; } int mid = (l + r) >> 1; if(p <= mid) ins(lson, l, mid, p, v); else ins(rson, mid + 1, r, p, v); pushup(x); } void del(int x, int l, int r, int p) { if(l == r) { t[x].sumv = t[x].maxv = 0; return ; } int mid = (l + r) >> 1; if(p <= mid) del(lson, l, mid, p); else del(rson, mid + 1, r, p); pushup(x); } ll query_sum(int l, int r, int x, int L, int R) { if(!x) return 0; if(l >= L && r <= R) return t[x].sumv; ll tmp = 0; int mid = (l + r) >> 1; if(L <= mid) tmp += query_sum(l, mid, lson, L, R); if(R > mid) tmp += query_sum(mid + 1, r, rson, L, R); return tmp; } ll query_max(int l, int r, int x, int L, int R) { if(!x) return -inf; if(l >= L && r <= R) return t[x].maxv; ll tmp = -inf; int mid = (l + r) >> 1; if(L <= mid) tmp = max(tmp, query_max(l, mid, lson, L, R)); if(R > mid) tmp = max(tmp, query_max(mid + 1, r, rson, L, R)); return tmp; } #undef lson #undef rson }; char str[10]; int n, Q, edges, tim; int hd[maxn], to[maxn << 1], nex[maxn << 1], W[maxn], C[maxn], fa[maxn], dep[maxn]; int ln[maxn], dfn[maxn], top[maxn], bot[maxn], siz[maxn], hson[maxn], rt[maxn]; void add(int u, int v) { nex[++edges] = hd[u], hd[u] = edges, to[edges] = v; } void dfs1(int u, int ff) { siz[u] = 1, fa[u] = ff, dep[u] = dep[ff] + 1; for(int i = hd[u]; i ; i = nex[i]) { int v = to[i]; if(v == ff) continue; dfs1(v, u); siz[u] += siz[v]; if(siz[hson[u]] < siz[v]) hson[u] = v; } } void dfs2(int u, int tp) { top[u] = tp, ln[++tim] = u, dfn[u] = tim; Seg :: ins(rt[C[u]], 1, N, tim, 1ll*W[u]); if(hson[u]) dfs2(hson[u], tp), bot[u] = bot[hson[u]]; else bot[u] = u; for(int i = hd[u]; i ; i = nex[i]) { int v = to[i]; if(v == fa[u] || v == hson[u]) continue; dfs2(v, v); } } ll _query_sum(int x, int y) { int ty = C[y]; ll tmp = 0; // y is the deeper one while(top[x] ^ top[y]) { if(dep[top[x]] > dep[top[y]]) swap(x, y); tmp += Seg :: query_sum(1, N, rt[ty], dfn[top[y]], dfn[y]); y = fa[top[y]]; } if(dep[x] > dep[y]) swap(x, y); tmp += Seg :: query_sum(1, N, rt[ty], dfn[x], dfn[y]); return tmp; } ll _query_max(int x, int y) { int ty = C[y]; ll tmp = 0; while(top[x] ^ top[y]) { if(dep[top[x]] > dep[top[y]]) swap(x, y); tmp = max(tmp, Seg :: query_max(1, N, rt[ty], dfn[top[y]], dfn[y])); y = fa[top[y]]; } if(dep[x] > dep[y]) swap(x, y); tmp = max(tmp, Seg :: query_max(1, N, rt[ty], dfn[x], dfn[y])); return tmp; } int main() { // setIO("input"); scanf("%d%d",&n,&Q); for(int i = 1;i <= n; ++i) scanf("%d%d",&W[i],&C[i]); for(int i = 1, u, v; i < n; ++i) { scanf("%d%d",&u,&v), add(u, v), add(v, u); } Seg :: t[0].maxv = -inf; dfs1(1, 0), dfs2(1, 1); while(Q--) { scanf("%s",str); int x, w, c, y; if(str[1] == 'C') { scanf("%d%d",&x,&c); Seg :: del(rt[C[x]], 1, N, dfn[x]); C[x] = c; Seg :: ins(rt[C[x]], 1, N, dfn[x], W[x]); } if(str[1] == 'W') { scanf("%d%d",&x,&w); W[x] = w; Seg :: ins(rt[C[x]], 1, N, dfn[x], 1ll*W[x]); } if(str[1] == 'S') { scanf("%d%d",&x,&y), printf("%lld\n",_query_sum(x, y)); } if(str[1] == 'M') { scanf("%d%d",&x,&y), printf("%lld\n",_query_max(x, y)); } } return 0; }