BZOJ 1036: [ZJOI2008]树的统计Count
题意:
一棵树上有 n 个节点,编号为 1 到 n,每个节点都有一个权值 w。你需要按以下形式对这棵树执行操作:I. CHANGE u t:把节点 u 的权值改为 t;II. QMAX u v:询问从节点 u 到节点 v 的路径上节点的最大权值;III. QSUM u v:询问从节点 u 到节点 v 的路径上节点的权值和。注意:从节点 u 到节点 v 的路径上的节点包括 u 和 v 本身。
题解:
这是一道树链剖分的模板题。不过我自己写的树链剖分出了错(T_T),只好参考网上的版本。
代码:
来源:http://coraon.com/zjoi-2008/
// BZOJ 1036 [ZJOI2008] Count: path queries with point updates on a tree,
// solved with heavy-light decomposition (HLD) on top of a sum/max segment tree.
//
// Input (test cases repeat until EOF):
//   n, then n-1 edges "u v", then n node weights, then q, then q operations:
//     CHANGE u t  -> set the weight of node u to t
//     QMAX u v    -> print the maximum weight on the path u..v (inclusive)
//     QSUM u v    -> print the sum of weights on the path u..v (inclusive)
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
#include <utility>

const int MAXN = 30001;      // nodes are numbered 1..n, so n must be < MAXN
const int INF = 0x3f3f3f3f;  // starting value for max queries; assumes every weight > -INF

int n;                     // node count of the current test case
int w[MAXN];               // w[u]: weight of node u (1-based)
int mw[MAXN];              // mw[dfn[u]] == w[u]: weights laid out in HLD visit order
std::vector<int> e[MAXN];  // undirected adjacency lists of the tree

// Returns true when x is a valid node id for the current test case.
static bool valid_node(int x) { return 1 <= x && x <= n; }

// Segment tree over positions 1..n holding range sum and range max.
// Sums are kept in int; NOTE(review): assumes |path sum| fits in int for the
// intended input bounds -- confirm before reusing with larger weights.
class Segment_Tree {
private:
    int sum[MAXN << 2];
    int upper[MAXN << 2];  // range maximum

    // Recompute node rt from its two children.
    void push_up(int rt) {
        sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
        upper[rt] = std::max(upper[rt << 1], upper[rt << 1 | 1]);
    }

public:
    // Build the tree from mw[l..r] (defaults to the whole range [1, n]).
    void build(int rt = 1, int l = 1, int r = n) {
        if (l == r) {
            sum[rt] = upper[rt] = mw[l];
            return;
        }
        const int m = (l + r) >> 1;
        build(rt << 1, l, m);
        build(rt << 1 | 1, m + 1, r);
        push_up(rt);
    }

    // Set position P to val.
    void update(int P, int val, int rt = 1, int l = 1, int r = n) {
        if (l == r) {
            sum[rt] = upper[rt] = val;
            return;
        }
        const int m = (l + r) >> 1;
        if (P <= m)
            update(P, val, rt << 1, l, m);
        else
            update(P, val, rt << 1 | 1, m + 1, r);
        push_up(rt);
    }

    // Query [L, R] (requires L <= R): the maximum if opt is true, else the sum.
    int query(int L, int R, bool opt, int rt = 1, int l = 1, int r = n) {
        if (L <= l && r <= R) return opt ? upper[rt] : sum[rt];
        const int m = (l + r) >> 1;
        if (R <= m) return query(L, R, opt, rt << 1, l, m);
        if (L > m) return query(L, R, opt, rt << 1 | 1, m + 1, r);
        const int left = query(L, R, opt, rt << 1, l, m);
        const int right = query(L, R, opt, rt << 1 | 1, m + 1, r);
        return opt ? std::max(left, right) : left + right;
    }
};

// Heavy-light decomposition of the tree in e[], rooted at node 1.
class HLD : public Segment_Tree {
public:
    int dep[MAXN], fa[MAXN], sz[MAXN];  // depth (root = 1), parent (root -> 0), subtree size
    int son[MAXN], top[MAXN], dfn[MAXN], dfs_clock;  // heavy child, chain head, visit index

    // Reset per-test state; must be called before dfs1/dfs2.
    void init() {
        memset(dep, 0, sizeof(dep));
        memset(son, 0, sizeof(son));
        dep[1] = 1;
        dfs_clock = 0;
    }

    // Pass 1: compute fa, dep, sz and the heavy child son[] for every node.
    // Iterative (BFS order, then reverse order) so long chains cannot overflow the stack.
    void dfs1(int root) {
        sz[0] = 0;  // son[u] == 0 means "no heavy child yet"; it must compare as size 0
        dep[root] = 1;
        fa[root] = 0;

        std::vector<int> order;
        order.reserve(n);
        order.push_back(root);
        for (std::size_t head = 0; head < order.size(); ++head) {
            const int u = order[head];
            for (std::size_t i = 0; i < e[u].size(); ++i) {
                const int v = e[u][i];
                if (dep[v]) continue;  // already discovered: this is u's parent
                dep[v] = dep[u] + 1;
                fa[v] = u;
                order.push_back(v);
            }
        }

        // Children appear after their parent in BFS order, so a reverse sweep
        // sees every subtree size before it is needed.
        for (std::size_t k = order.size(); k-- > 0;) {
            const int u = order[k];
            sz[u] = 1;
            son[u] = 0;
            for (std::size_t i = 0; i < e[u].size(); ++i) {
                const int v = e[u][i];
                if (fa[v] != u) continue;  // skip the parent edge
                sz[u] += sz[v];
                if (sz[son[u]] < sz[v]) son[u] = v;  // first largest child wins ties
            }
        }
    }

    // Pass 2: assign dfn (heavy child visited first, so each heavy chain gets
    // consecutive positions), chain heads top[], and copy weights into mw[].
    // Iterative preorder with an explicit stack; same order as the recursive version.
    void dfs2(int root, int tp) {
        std::vector<int> stk;
        stk.reserve(n);
        top[root] = tp;
        stk.push_back(root);
        while (!stk.empty()) {
            const int u = stk.back();
            stk.pop_back();
            dfn[u] = ++dfs_clock;
            mw[dfn[u]] = w[u];
            // Light children each start a new chain; push them in reverse so they
            // pop in adjacency order.
            for (int i = static_cast<int>(e[u].size()) - 1; i >= 0; --i) {
                const int v = e[u][i];
                if (v == fa[u] || v == son[u]) continue;
                top[v] = v;
                stk.push_back(v);
            }
            // The heavy child is pushed last so it is visited right after u,
            // extending u's chain.
            if (son[u]) {
                top[son[u]] = top[u];
                stk.push_back(son[u]);
            }
        }
    }

    // Sum of weights on the path u..v (inclusive).
    int getsum(int u, int v) {
        int ans = 0;
        // Keep lifting the endpoint whose chain head is deeper until both lie on one chain.
        while (top[u] != top[v]) {
            if (dep[top[u]] > dep[top[v]]) std::swap(u, v);
            ans += query(dfn[top[v]], dfn[v], false);
            v = fa[top[v]];
        }
        if (dep[u] > dep[v]) std::swap(u, v);
        ans += query(dfn[u], dfn[v], false);  // same chain: one contiguous interval
        return ans;
    }

    // Maximum weight on the path u..v (inclusive).
    int getmax(int u, int v) {
        int ans = -INF;
        while (top[u] != top[v]) {
            if (dep[top[u]] > dep[top[v]]) std::swap(u, v);
            ans = std::max(ans, query(dfn[top[v]], dfn[v], true));
            v = fa[top[v]];
        }
        if (dep[u] > dep[v]) std::swap(u, v);
        ans = std::max(ans, query(dfn[u], dfn[v], true));
        return ans;
    }
} hld;

// Reads test cases until EOF; returns 1 on malformed or out-of-range input
// instead of indexing arrays with unread or invalid values.
int main() {
#ifdef _DEBUG
    freopen("d:\\2008.txt", "r", stdin);
#endif
    char opt[16];
    int u, v, m;
    while (scanf("%d", &n) == 1) {
        if (n < 1 || n >= MAXN) return 1;  // arrays hold nodes 1..MAXN-1 only
        for (int i = 1; i <= n; ++i) e[i].clear();
        hld.init();
        for (int i = 1; i < n; ++i) {
            if (scanf("%d %d", &u, &v) != 2) return 1;
            if (!valid_node(u) || !valid_node(v)) return 1;
            e[u].push_back(v);
            e[v].push_back(u);
        }
        for (int i = 1; i <= n; ++i)
            if (scanf("%d", w + i) != 1) return 1;
        hld.dfs1(1);
        hld.dfs2(1, 1);
        hld.build();
        if (scanf("%d", &m) != 1) return 1;
        for (int i = 0; i < m; ++i) {
            // Field width keeps the operation name within opt[].
            if (scanf("%15s %d %d", opt, &u, &v) != 3) return 1;
            if (!valid_node(u)) return 1;
            if (opt[0] == 'C') {
                hld.update(hld.dfn[u], v);  // v is the new weight t here
            } else {
                if (!valid_node(v)) return 1;
                if (opt[1] == 'M')
                    printf("%d\n", hld.getmax(u, v));
                else
                    printf("%d\n", hld.getsum(u, v));
            }
        }
    }
    return 0;
}