[树链剖分][线段树] 洛谷 P2590 树的统计

题目描述

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。

我们将以下面的形式来要求你对这棵树完成一些操作:

I. CHANGE u t : 把结点u的权值改为t

II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值

III. QSUM u v: 询问从点u到点v的路径上的节点的权值和

注意:从点u到点v的路径上的节点包括u和v本身

输入输出格式

输入格式:

 

输入文件的第一行为一个整数n,表示节点的个数。

接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。

接下来一行n个整数,第i个整数wi表示节点i的权值。

接下来1行,为一个整数q,表示操作的总数。

接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。

 

输出格式:

 

对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

 

输入输出样例

输入样例#1:
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
输出样例#1: 
4
1
2
2
10
6
5
6
5
16

说明

对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

 

 

题解

  • 没什么好说的板题

代码

  1 #include <cstdio> 
  2 #include <iostream>
  3 #include <cstring>
  4 #define N 100010
  5 #define inf 1000000000
  6 using namespace std;
  7 struct edge{ int to,from; }e[N];
  8 int sum[4*N],mx[4*N],dep[N],size[N],fa[N],head[N],tid[N],rank[N],top[N],a[4*N],son[N],cnt,tot,n,m,q;
  9 void insert(int x,int y) { e[++cnt].to=y; e[cnt].from=head[x]; head[x]=cnt; }
 10 void dfs(int x,int pre)
 11 {
 12     fa[x]=pre,dep[x]=dep[pre]+1,size[x]=1;
 13     for (int i=head[x];i;i=e[i].from) 
 14         if (e[i].to!=pre)
 15         {
 16             dfs(e[i].to,x),size[x]+=size[e[i].to];
 17             if (size[e[i].to]>size[son[x]]) son[x]=e[i].to;
 18         }
 19 }
 20 void dfs1(int x,int pre)
 21 {
 22     tid[x]=++tot,rank[tot]=x,top[x]=pre;
 23     if (son[x]) dfs1(son[x],pre);
 24     for (int i=head[x];i;i=e[i].from)
 25         if (e[i].to!=son[x]&&e[i].to!=fa[x])
 26             dfs1(e[i].to,e[i].to);
 27 }
 28 void build(int d,int l,int r)
 29 {
 30     if (l==r)
 31     {
 32         sum[d]=mx[d]=a[rank[l]];
 33         return;
 34     }
 35     int mid=(l+r)>>1;
 36     build(d*2,l,mid),build(d*2+1,mid+1,r);
 37     sum[d]=sum[d*2]+sum[d*2+1],mx[d]=max(mx[d*2],mx[d*2+1]);
 38 }
 39 void updata(int d,int l,int r,int x,int y)
 40 {
 41     if (l==r)
 42     {
 43         sum[d]=mx[d]=y;
 44         return;
 45     }
 46     int mid=(l+r)>>1;
 47     if (x<=mid) updata(d*2,l,mid,x,y); else updata(d*2+1,mid+1,r,x,y);
 48     sum[d]=sum[d*2]+sum[d*2+1],mx[d]=max(mx[d*2],mx[d*2+1]);
 49 }
 50 int querymax(int d,int l,int r,int x,int y)
 51 {
 52     int ans=-inf,mid=(l+r)>>1;
 53     if (x<=l&&r<=y) return mx[d];
 54     if (x<=mid)    ans=max(ans,querymax(d*2,l,mid,x,y));
 55     if (y>mid) ans=max(ans,querymax(d*2+1,mid+1,r,x,y));
 56     sum[d]=sum[d*2]+sum[d*2+1],mx[d]=max(mx[d*2],mx[d*2+1]);
 57     return ans;
 58 }
 59 int querysum(int d,int l,int r,int x,int y)
 60 {
 61     int ans=0,mid=(l+r)>>1;
 62     if (x<=l&&r<=y) return sum[d];
 63     if (x<=mid)    ans+=querysum(d*2,l,mid,x,y);
 64     if (y>mid) ans+=querysum(d*2+1,mid+1,r,x,y);
 65     sum[d]=sum[d*2]+sum[d*2+1],mx[d]=max(mx[d*2],mx[d*2+1]);    
 66     return ans;
 67 }
 68 int getsum(int u,int v)
 69 {
 70     int ans=0;
 71     while (top[u]!=top[v])
 72     {
 73         if (dep[top[u]]<dep[top[v]]) swap(u,v);
 74         ans+=querysum(1,1,n,tid[top[u]],tid[u]);
 75         u=fa[top[u]];
 76     }
 77     if (dep[u]<dep[v]) swap(u,v);
 78     ans+=querysum(1,1,n,tid[v],tid[u]);
 79     return ans;
 80 }
 81 int getmax(int u,int v)
 82 {
 83     int ans=-inf;
 84     while (top[u]!=top[v])
 85     {
 86         if (dep[top[u]]<dep[top[v]]) swap(u,v);
 87         ans=max(ans,querymax(1,1,n,tid[top[u]],tid[u]));
 88         u=fa[top[u]];
 89     }
 90     if (dep[u]<dep[v]) swap(u,v);
 91     ans=max(ans,querymax(1,1,n,tid[v],tid[u]));
 92     return ans;
 93 }
 94 int main()
 95 {
 96     scanf("%d",&n);
 97     for (int i=1,x,y;i<n;i++) scanf("%d%d",&x,&y),insert(x,y),insert(y,x);
 98     for (int i=1;i<=n;i++) scanf("%d",&a[i]);
 99     dfs(1,0); fa[1]=1,dfs1(1,1); build(1,1,n);
100     scanf("%d",&q);
101     while (q--)
102     {
103         char ch[10]; int x,y;
104         scanf("%s%d%d",ch,&x,&y);
105         if (ch[1]=='H') updata(1,1,n,tid[x],y);
106         if (ch[1]=='M') printf("%d\n",getmax(x,y));
107         if (ch[1]=='S') printf("%d\n",getsum(x,y));
108     }
109 }

 

posted @ 2018-10-17 16:00  BEYang_Z  阅读(126)  评论(0编辑  收藏  举报