[树链剖分][线段树] 洛谷 P2590 树的统计
题目描述
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
输入输出格式
输入格式:
输入文件的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来一行n个整数,第i个整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
输出格式:
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
输入输出样例
说明
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
题解
- 没什么好说的板题
代码
1 #include <cstdio> 2 #include <iostream> 3 #include <cstring> 4 #define N 100010 5 #define inf 1000000000 6 using namespace std; 7 struct edge{ int to,from; }e[N]; 8 int sum[4*N],mx[4*N],dep[N],size[N],fa[N],head[N],tid[N],rank[N],top[N],a[4*N],son[N],cnt,tot,n,m,q; 9 void insert(int x,int y) { e[++cnt].to=y; e[cnt].from=head[x]; head[x]=cnt; } 10 void dfs(int x,int pre) 11 { 12 fa[x]=pre,dep[x]=dep[pre]+1,size[x]=1; 13 for (int i=head[x];i;i=e[i].from) 14 if (e[i].to!=pre) 15 { 16 dfs(e[i].to,x),size[x]+=size[e[i].to]; 17 if (size[e[i].to]>size[son[x]]) son[x]=e[i].to; 18 } 19 } 20 void dfs1(int x,int pre) 21 { 22 tid[x]=++tot,rank[tot]=x,top[x]=pre; 23 if (son[x]) dfs1(son[x],pre); 24 for (int i=head[x];i;i=e[i].from) 25 if (e[i].to!=son[x]&&e[i].to!=fa[x]) 26 dfs1(e[i].to,e[i].to); 27 } 28 void build(int d,int l,int r) 29 { 30 if (l==r) 31 { 32 sum[d]=mx[d]=a[rank[l]]; 33 return; 34 } 35 int mid=(l+r)>>1; 36 build(d*2,l,mid),build(d*2+1,mid+1,r); 37 sum[d]=sum[d*2]+sum[d*2+1],mx[d]=max(mx[d*2],mx[d*2+1]); 38 } 39 void updata(int d,int l,int r,int x,int y) 40 { 41 if (l==r) 42 { 43 sum[d]=mx[d]=y; 44 return; 45 } 46 int mid=(l+r)>>1; 47 if (x<=mid) updata(d*2,l,mid,x,y); else updata(d*2+1,mid+1,r,x,y); 48 sum[d]=sum[d*2]+sum[d*2+1],mx[d]=max(mx[d*2],mx[d*2+1]); 49 } 50 int querymax(int d,int l,int r,int x,int y) 51 { 52 int ans=-inf,mid=(l+r)>>1; 53 if (x<=l&&r<=y) return mx[d]; 54 if (x<=mid) ans=max(ans,querymax(d*2,l,mid,x,y)); 55 if (y>mid) ans=max(ans,querymax(d*2+1,mid+1,r,x,y)); 56 sum[d]=sum[d*2]+sum[d*2+1],mx[d]=max(mx[d*2],mx[d*2+1]); 57 return ans; 58 } 59 int querysum(int d,int l,int r,int x,int y) 60 { 61 int ans=0,mid=(l+r)>>1; 62 if (x<=l&&r<=y) return sum[d]; 63 if (x<=mid) ans+=querysum(d*2,l,mid,x,y); 64 if (y>mid) ans+=querysum(d*2+1,mid+1,r,x,y); 65 sum[d]=sum[d*2]+sum[d*2+1],mx[d]=max(mx[d*2],mx[d*2+1]); 66 return ans; 67 } 68 int getsum(int u,int v) 69 { 70 int ans=0; 71 while (top[u]!=top[v]) 72 { 73 if (dep[top[u]]<dep[top[v]]) swap(u,v); 74 ans+=querysum(1,1,n,tid[top[u]],tid[u]); 75 u=fa[top[u]]; 76 } 77 if (dep[u]<dep[v]) swap(u,v); 78 ans+=querysum(1,1,n,tid[v],tid[u]); 79 return ans; 80 } 81 int getmax(int u,int v) 82 { 83 int ans=-inf; 84 while (top[u]!=top[v]) 85 { 86 if (dep[top[u]]<dep[top[v]]) swap(u,v); 87 ans=max(ans,querymax(1,1,n,tid[top[u]],tid[u])); 88 u=fa[top[u]]; 89 } 90 if (dep[u]<dep[v]) swap(u,v); 91 ans=max(ans,querymax(1,1,n,tid[v],tid[u])); 92 return ans; 93 } 94 int main() 95 { 96 scanf("%d",&n); 97 for (int i=1,x,y;i<n;i++) scanf("%d%d",&x,&y),insert(x,y),insert(y,x); 98 for (int i=1;i<=n;i++) scanf("%d",&a[i]); 99 dfs(1,0); fa[1]=1,dfs1(1,1); build(1,1,n); 100 scanf("%d",&q); 101 while (q--) 102 { 103 char ch[10]; int x,y; 104 scanf("%s%d%d",ch,&x,&y); 105 if (ch[1]=='H') updata(1,1,n,tid[x],y); 106 if (ch[1]=='M') printf("%d\n",getmax(x,y)); 107 if (ch[1]=='S') printf("%d\n",getsum(x,y)); 108 } 109 }