POJ 3237 Tree
学习了一下树链剖分（heavy-light decomposition）：把树剖分成若干条链，再用线段树、树状数组、splay 等数据结构来维护这些链。
// POJ 3237 Tree
/*
 * Problem: a tree whose edges carry integer weights, with operations
 *   CHANGE i v  - set the weight of the i-th input edge to v;
 *   NEGATE a b  - negate every edge weight on the path a..b;
 *   QUERY  a b  - print the maximum edge weight on the path a..b;
 *   DONE        - end of the current test case.
 * (main only inspects the first letter of each operation word.)
 *
 * Approach: heavy-light decomposition lays every heavy chain out as a
 * contiguous range of a segment tree. Edge (fa[u], u) is stored at position
 * p[u] of its deeper endpoint u, so the root's position holds no edge.
 * Each segment-tree node keeps both max and min so a lazy "negate" tag can
 * be applied in O(1): after negation the new max is -min and the new min
 * is -max.
 */
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

const int maxn = 10010;

struct Edge {
    int u, v;
    int nxt;
} edge[maxn * 2];
int head[maxn];
int tot;

// Append the directed edge u -> v to u's adjacency list.
void addEdge(int u, int v) {
    edge[tot].u = u;
    edge[tot].v = v;
    edge[tot].nxt = head[u];
    head[u] = tot++;
}

// ---------------- Heavy-light decomposition ----------------
int fa[maxn];    // parent (dfs1)
int deep[maxn];  // depth, root = 0 (dfs1)
int num[maxn];   // subtree size (dfs1)
int son[maxn];   // heavy child, -1 for a leaf (dfs1)
int top[maxn];   // head of the heavy chain containing the node (dfs2)
int p[maxn];     // segment-tree position of edge (fa[u], u) (dfs2)
int fp[maxn];    // inverse of p: fp[p[u]] == u (dfs2)
int pos;         // next free segment-tree position

// First pass: parent, depth, subtree size and heavy child of every node.
void dfs1(int u, int pre, int d) {
    deep[u] = d;
    fa[u] = pre;
    num[u] = 1;
    for (int i = head[u]; i != -1; i = edge[i].nxt) {
        int v = edge[i].v;
        if (v != pre) {
            dfs1(v, u, d + 1);
            num[u] += num[v];
            if (son[u] == -1 || num[v] > num[son[u]]) son[u] = v;
        }
    }
}

// Second pass: assign chain heads and positions. The heavy child is visited
// first, so every heavy chain occupies consecutive positions and
// p[son[u]] == p[u] + 1.
void dfs2(int u, int sp) {
    top[u] = sp;
    p[u] = pos++;   // position of edge (fa[u], u); unused for the root
    fp[p[u]] = u;
    if (son[u] == -1) return;
    dfs2(son[u], sp);
    for (int i = head[u]; i != -1; i = edge[i].nxt) {
        int v = edge[i].v;
        if (v != son[u] && v != fa[u]) dfs2(v, v);  // light child starts a new chain
    }
}

// ---------------- Segment tree with lazy negation ----------------
struct Node {
    int l, r;   // covered position range [l, r]
    int maxm;   // maximum weight in the range
    int minn;   // minimum weight in the range
    int ne;     // lazy tag: 1 if the children still have to be negated
} segTree[maxn * 4];

// Build over positions [l, r] with every weight and lazy tag cleared.
// Clearing `ne` keeps tags left over from a previous test case from leaking
// into this one.
void build(int rt, int l, int r) {
    segTree[rt].l = l;
    segTree[rt].r = r;
    segTree[rt].maxm = 0;
    segTree[rt].minn = 0;
    segTree[rt].ne = 0;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(rt << 1, l, mid);
    build((rt << 1) | 1, mid + 1, r);
}

// Negate every weight covered by node i (max/min swap roles) and toggle its tag.
void apply_negate(int i) {
    int old_max = segTree[i].maxm;
    segTree[i].maxm = -segTree[i].minn;
    segTree[i].minn = -old_max;
    segTree[i].ne ^= 1;
}

// Forward a pending negation from node i to its two children.
void push_down(int i) {
    if (segTree[i].l == segTree[i].r) return;
    if (segTree[i].ne) {
        apply_negate(i << 1);
        apply_negate((i << 1) | 1);
        segTree[i].ne = 0;
    }
}

// Recompute node i's max/min from its children.
void push_up(int i) {
    segTree[i].maxm = max(segTree[i << 1].maxm, segTree[(i << 1) | 1].maxm);
    segTree[i].minn = min(segTree[i << 1].minn, segTree[(i << 1) | 1].minn);
}

// Assign val to position k.
void update(int i, int k, int val) {
    if (segTree[i].l == k && segTree[i].r == k) {
        segTree[i].maxm = val;
        segTree[i].minn = val;
        segTree[i].ne = 0;
        return;
    }
    push_down(i);  // pending negations must reach the children before one is overwritten
    int mid = (segTree[i].l + segTree[i].r) / 2;
    if (k <= mid) update(i << 1, k, val);
    else update((i << 1) | 1, k, val);
    push_up(i);
}

// Reset the adjacency lists and decomposition state for a new test case.
void init() {
    memset(head, -1, sizeof(head));
    tot = 0;
    pos = 0;
    memset(son, -1, sizeof(son));
}

// e[i] = {a, b, w} for input edge i; after setup e[i][1] is the deeper endpoint.
int e[maxn][3];

// Negate every weight in positions [l, r].
void ne_update(int rt, int l, int r) {
    if (segTree[rt].l == l && segTree[rt].r == r) {
        apply_negate(rt);
        return;
    }
    push_down(rt);
    int mid = (segTree[rt].l + segTree[rt].r) / 2;
    if (r <= mid) {
        ne_update(rt << 1, l, r);
    } else if (l > mid) {
        ne_update((rt << 1) | 1, l, r);
    } else {
        ne_update(rt << 1, l, mid);
        ne_update((rt << 1) | 1, mid + 1, r);
    }
    push_up(rt);
}

// Negate every edge weight on the tree path u..v.
void Negate(int u, int v) {
    int f1 = top[u], f2 = top[v];
    while (f1 != f2) {
        if (deep[f1] < deep[f2]) {  // always lift the endpoint whose chain head is deeper
            swap(f1, f2);
            swap(u, v);
        }
        ne_update(1, p[f1], p[u]);
        u = fa[f1];
        f1 = top[u];
    }
    if (u == v) return;
    if (deep[u] > deep[v]) swap(u, v);
    // u is now the shallower endpoint on the shared chain; its own position
    // holds the edge above it, so the range starts at its heavy child.
    ne_update(1, p[son[u]], p[v]);
}

// Maximum weight in positions [l, r].
int query(int rt, int l, int r) {
    if (segTree[rt].l == l && segTree[rt].r == r) return segTree[rt].maxm;
    push_down(rt);
    int mid = (segTree[rt].l + segTree[rt].r) / 2;
    if (r <= mid) return query(rt << 1, l, r);
    if (l > mid) return query((rt << 1) | 1, l, r);
    return max(query(rt << 1, l, mid), query((rt << 1) | 1, mid + 1, r));
}

// Maximum edge weight on the tree path u..v.
// Returns the sentinel -100000000 when the path has no edges (u == v).
int findMax(int u, int v) {
    int f1 = top[u], f2 = top[v];
    int tmp = -100000000;
    while (f1 != f2) {
        if (deep[f1] < deep[f2]) {
            swap(f1, f2);
            swap(u, v);
        }
        tmp = max(tmp, query(1, p[f1], p[u]));
        u = fa[f1];
        f1 = top[u];
    }
    if (u == v) return tmp;
    if (deep[u] > deep[v]) swap(u, v);
    return max(tmp, query(1, p[son[u]], p[v]));
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        int n;
        scanf("%d", &n);
        init();
        // A tree on n nodes has exactly n - 1 edges. Reading n triples would
        // fail on the first operation word and then add a bogus edge built
        // from the unset (or previous test case's) e[n-1].
        for (int i = 0; i < n - 1; ++i) {
            scanf("%d%d%d", &e[i][0], &e[i][1], &e[i][2]);
            addEdge(e[i][0], e[i][1]);
            addEdge(e[i][1], e[i][0]);
        }
        dfs1(1, 0, 0);
        dfs2(1, 1);
        build(1, 0, pos - 1);
        // Store every edge at the position of its deeper endpoint.
        for (int i = 0; i < n - 1; ++i) {
            if (deep[e[i][0]] > deep[e[i][1]]) swap(e[i][0], e[i][1]);
            update(1, p[e[i][1]], e[i][2]);
        }
        char op[10];
        int u, v;
        while (scanf("%9s", op) == 1) {  // width limit keeps op[] from overflowing
            if (op[0] == 'D') break;
            scanf("%d%d", &u, &v);
            if (op[0] == 'C') {
                update(1, p[e[u - 1][1]], v);  // CHANGE: edge u (1-based) gets weight v
            } else if (op[0] == 'N') {
                Negate(u, v);
            } else {
                printf("%d\n", findMax(u, v));
            }
        }
    }
    return 0;
}
HYSBZ 1036 树的统计Count
与上一题不同，本题是点权：线段树维护的是按剖分顺序排列的各点权值。由于本题只有单点修改，所以不需要延迟标记。
// HYSBZ 1036 Tree statistics (Count).
// Node-weighted tree with point assignment and path max / path sum queries.
// main treats an operation starting with 'C' as a point assignment, "QMAX"
// as a path-maximum query, and anything else as a path-sum query.
// Heavy-light decomposition lays the nodes out in a segment tree that keeps
// range sum and range maximum; only point updates occur, so no lazy tags.
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <string>
#include <math.h>
#include <stdlib.h>
using namespace std;

const int MAXN = 30010;

struct Edge
{
    int to, next;
} edge[MAXN * 2];
int head[MAXN], tot;

int top[MAXN];   // head of the heavy chain that contains v
int fa[MAXN];    // parent of v
int deep[MAXN];  // depth of v (root = 0)
int num[MAXN];   // number of nodes in the subtree rooted at v
int p[MAXN];     // position of v in the segment tree
int fp[MAXN];    // inverse of p: fp[p[v]] == v
int son[MAXN];   // heavy child of v, -1 if none
int pos;         // next unused segment-tree position

// Reset adjacency lists and decomposition state before a new test case.
void init()
{
    memset(son, -1, sizeof(son));
    memset(head, -1, sizeof(head));
    tot = 0;
    pos = 0;
}

// Prepend the directed edge u -> v to u's adjacency list.
void addedge(int u, int v)
{
    Edge &slot = edge[tot];
    slot.to = v;
    slot.next = head[u];
    head[u] = tot;
    ++tot;
}

// First pass: parent, depth, subtree size and heavy child.
// Ties in subtree size keep the earliest child in adjacency order.
void dfs1(int u, int pre, int d)
{
    fa[u] = pre;
    deep[u] = d;
    num[u] = 1;
    for (int eid = head[u]; ~eid; eid = edge[eid].next) {
        const int child = edge[eid].to;
        if (child == pre) continue;
        dfs1(child, u, d + 1);
        num[u] += num[child];
        if (son[u] == -1 || num[son[u]] < num[child]) son[u] = child;
    }
}

// Second pass: chain heads and segment-tree positions. The heavy child is
// placed right after its parent so every heavy chain is contiguous.
void getpos(int u, int sp)
{
    top[u] = sp;
    fp[pos] = u;
    p[u] = pos++;
    if (son[u] == -1) return;
    getpos(son[u], sp);
    for (int eid = head[u]; ~eid; eid = edge[eid].next) {
        const int child = edge[eid].to;
        if (child == son[u] || child == fa[u]) continue;
        getpos(child, child);  // a light child heads its own chain
    }
}

struct Node
{
    int l, r;  // covered position range [l, r]
    int sum;   // sum of node weights in the range
    int Max;   // maximum node weight in the range
} segTree[MAXN * 3];

// Recompute node i's aggregates from its two children.
void push_up(int i)
{
    const Node &lc = segTree[i << 1];
    const Node &rc = segTree[(i << 1) | 1];
    segTree[i].sum = lc.sum + rc.sum;
    segTree[i].Max = max(lc.Max, rc.Max);
}

int s[MAXN];  // initial weight of each node, indexed by node id

// Build over positions [l, r]; leaf l takes the weight of node fp[l].
void build(int i, int l, int r)
{
    Node &cur = segTree[i];
    cur.l = l;
    cur.r = r;
    if (l != r) {
        const int mid = (l + r) / 2;
        build(i << 1, l, mid);
        build((i << 1) | 1, mid + 1, r);
        push_up(i);
    } else {
        cur.sum = cur.Max = s[fp[l]];
    }
}

// Set position k to val, then refresh every ancestor back up to node i.
void update(int i, int k, int val)
{
    int cur = i;
    while (!(segTree[cur].l == k && segTree[cur].r == k)) {
        const int mid = (segTree[cur].l + segTree[cur].r) / 2;
        cur = (k <= mid) ? (cur << 1) : ((cur << 1) | 1);
    }
    segTree[cur].sum = segTree[cur].Max = val;
    while (cur != i) {
        cur >>= 1;  // children of c are 2c and 2c+1, so c >> 1 is the parent
        push_up(cur);
    }
}

// Maximum over positions [l, r].
int queryMax(int i, int l, int r)
{
    const Node &cur = segTree[i];
    if (cur.l == l && cur.r == r) return cur.Max;
    const int mid = (cur.l + cur.r) / 2;
    const int lc = i << 1;
    const int rc = lc | 1;
    if (l > mid) return queryMax(rc, l, r);
    if (r <= mid) return queryMax(lc, l, r);
    return max(queryMax(lc, l, mid), queryMax(rc, mid + 1, r));
}

// Sum over positions [l, r].
int querySum(int i, int l, int r)
{
    const Node &cur = segTree[i];
    if (cur.l == l && cur.r == r) return cur.sum;
    const int mid = (cur.l + cur.r) / 2;
    const int lc = i << 1;
    const int rc = lc | 1;
    if (l > mid) return querySum(rc, l, r);
    if (r <= mid) return querySum(lc, l, r);
    return querySum(lc, l, mid) + querySum(rc, mid + 1, r);
}

// Maximum node weight on the tree path u..v (both endpoints included).
int findMax(int u, int v)
{
    int best = -1000000000;
    while (top[u] != top[v]) {
        // Lift whichever endpoint sits on the chain with the deeper head.
        if (deep[top[u]] < deep[top[v]]) swap(u, v);
        best = max(best, queryMax(1, p[top[u]], p[u]));
        u = fa[top[u]];
    }
    if (deep[u] > deep[v]) swap(u, v);
    return max(best, queryMax(1, p[u], p[v]));
}

// Sum of node weights on the tree path u..v (both endpoints included).
int findSum(int u, int v)
{
    int total = 0;
    while (top[u] != top[v]) {
        if (deep[top[u]] < deep[top[v]]) swap(u, v);
        total += querySum(1, p[top[u]], p[u]);
        u = fa[top[u]];
    }
    if (deep[u] > deep[v]) swap(u, v);
    return total + querySum(1, p[u], p[v]);
}

int main()
{
    int n;
    while (scanf("%d", &n) == 1) {
        init();
        for (int k = 1; k < n; ++k) {
            int a, b;
            scanf("%d%d", &a, &b);
            addedge(a, b);
            addedge(b, a);
        }
        for (int k = 1; k <= n; ++k) scanf("%d", &s[k]);
        dfs1(1, 0, 0);
        getpos(1, 1);
        build(1, 0, pos - 1);

        int q;
        scanf("%d", &q);
        char op[20];
        int u, v;
        for (; q != 0; --q) {
            scanf("%s%d%d", op, &u, &v);
            if (op[0] == 'C') {
                update(1, p[u], v);  // point assignment
            } else {
                const int ans = (strcmp(op, "QMAX") == 0) ? findMax(u, v) : findSum(u, v);
                printf("%d\n", ans);
            }
        }
    }
    return 0;
}
参考链接:http://www.cnblogs.com/kuangbin/category/507663.html