2008ZJOI树的统计

codevs 2460 树的统计

 http://codevs.cn/problem/2460/

2008年省队选拔赛浙江

 题目等级 : 大师 Master
 
题目描述 Description

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。

我们将以下面的形式来要求你对这棵树完成一些操作:

  1. I.                    CHANGE u t : 把结点u的权值改为t
  2. II.                 QMAX u v: 询问从点u到点v的路径上的节点的最大权值
  3. III.               QSUM u v: 询问从点u到点v的路径上的节点的权值和

 

注意:从点u到点v的路径上的节点包括u和v本身

输入描述 Input Description

输入文件的第一行为一个整数n,表示节点的个数。

       接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。

       接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。

       接下来1行,为一个整数q,表示操作的总数。

接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。

 

输出描述 Output Description

       对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

样例输入 Sample Input

4

1 2

2 3

4 1

4 2 1 3

12

QMAX 3 4

QMAX 3 3

QMAX 3 2

QMAX 2 3

QSUM 3 4

QSUM 2 1

CHANGE 1 5

QMAX 3 4

CHANGE 3 6

QMAX 3 4

QMAX 2 4

QSUM 3 4

样例输出 Sample Output

4

1

2

2

10

6

5

6

5

16

数据范围及提示 Data Size & Hint

对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

树链剖分模板题,使用线段树维护

代码中变量含义:

n,节点个数    head[N],链表;
fa[N],存储父节点     son[N],统计包括自己在内的儿子个数  

id[N], 给每个节点按树链剖分的顺序重新编号    dep[N],节点深度  

bl[N],该节点所在重链的顶部节点    a[N],初始节点权值

struct node{int to,next;}e[N*2]; 链表
struct tree{int l,r,sum,maxx,size;}tr[N*2]; 线段树

注:代码中采用2*n空间方式建立线段树,k的左儿子为k+1,右儿子为k+k左儿子节点数*2
dfs_cnt,建立线段树时节点编号    cnt,链表    sz,树链剖分中节点访问顺序

#include<cstdio>
#include<algorithm>
#define N 30001
using namespace std;
int n,head[N];
int fa[N],son[N],id[N],dep[N],bl[N],a[N];
struct node{int to,next;}e[N*2];
struct tree{int l,r,sum,maxx,size;}tr[N*2];
int dfs_cnt,cnt,sz;
inline void insert(int u,int v)
{
    e[++cnt].to=v;e[cnt].next=head[u];head[u]=cnt;
    e[++cnt].to=u;e[cnt].next=head[v];head[v]=cnt;
}
void init()
{
    scanf("%d",&n);
    int u,v;
    for(int i=1;i<n;i++) 
    {
        scanf("%d%d",&u,&v);
        insert(u,v);
    }
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
}
inline void build(int l,int r)
{
    dfs_cnt++;tr[dfs_cnt].l=l;tr[dfs_cnt].r=r;tr[dfs_cnt].size=r-l+1;
    if(l==r) return;
    int mid=l+r>>1;
    build(l,mid);build(mid+1,r);
}
inline void dfs1(int k)
{
    son[k]++;
    for(int i=head[k];i;i=e[i].next)
    {
        if(e[i].to==fa[k]) continue;
        fa[e[i].to]=k;
        dep[e[i].to]=dep[k]+1;
        dfs1(e[i].to);
        son[k]+=son[e[i].to];
    }
}
inline void dfs2(int k,int chain)
{
    int m=0;sz++;
    id[k]=sz;
    bl[k]=chain;
    for(int i=head[k];i;i=e[i].next)
    {
        if(e[i].to==fa[k]) continue;
        if(son[e[i].to]>son[m]) m=e[i].to;
    }
    if(!m) return;
    dfs2(m,chain);
    for(int i=head[k];i;i=e[i].next)
    {
        if(e[i].to==m||e[i].to==fa[k]) continue;
        dfs2(e[i].to,e[i].to);
    } 
}
inline void change(int k,int pos,int w)
{
    if(tr[k].l==tr[k].r) {tr[k].sum=tr[k].maxx=w;return;}
    int mid=tr[k].l+tr[k].r>>1,l=k+1,r=k+tr[k+1].size*2;
    if(pos<=mid) change(l,pos,w);
    else change(r,pos,w);
    tr[k].sum=tr[l].sum+tr[r].sum;
    tr[k].maxx=max(tr[l].maxx,tr[r].maxx); 
} 
inline int querymx(int k,int opl,int opr)
{
    if(tr[k].l==opl&&tr[k].r==opr) return tr[k].maxx;
    int mid=tr[k].l+tr[k].r>>1,l=k+1,r=k+tr[k+1].size*2;
    if(opr<=mid)  return querymx(l,opl,opr);
    else if(opl>mid) return querymx(r,opl,opr);
    else return max(querymx(l,opl,mid),querymx(r,mid+1,opr)); 
}
inline int querysum(int k,int opl,int opr)
{
    if(tr[k].l==opl&&tr[k].r==opr) return tr[k].sum;
    int mid=tr[k].l+tr[k].r>>1,l=k+1,r=k+tr[k+1].size*2;
    if(opr<=mid) return querysum(l,opl,opr);
    else if(opl>mid) return querysum(r,opl,opr);
    else return querysum(l,opl,mid)+querysum(r,mid+1,opr);
}
inline void solvemx(int u,int v)
{
    int ans=-0x7fffffff;
    while(bl[u]!=bl[v])
    {
        if(dep[bl[u]]<dep[bl[v]]) swap(u,v);
        ans=max(ans,querymx(1,id[bl[u]],id[u]));
        u=fa[bl[u]];
    }
    if(id[u]>id[v]) swap(u,v);
    ans=max(ans,querymx(1,id[u],id[v]));
    printf("%d\n",ans);
}
inline void solvesum(int u,int v)
{
    int ans=0;
    while(bl[u]!=bl[v])
    {
        if(dep[bl[u]]<dep[bl[v]]) swap(u,v);
        ans+=querysum(1,id[bl[u]],id[u]);
        u=fa[bl[u]];
    }
    if(id[u]>id[v]) swap(u,v);
    ans+=querysum(1,id[u],id[v]);
    printf("%d\n",ans);
}
void solve()
{
    build(1,n);
    for(int i=1;i<=n;i++) 
        change(1,id[i],a[i]);
    int q,u,v;char c[7];
    scanf("%d",&q);
    for(int i=1;i<=q;i++)
    {
        scanf("%s%d%d",c,&u,&v);
        if(c[0]=='C') change(1,id[u],v);
        else if (c[1]=='M') solvemx(u,v); 
        else solvesum(u,v);
    }
}
int main()
{
    init();
    dfs1(1);
    dfs2(1,1);
    solve();
}

 

posted @ 2017-02-09 17:48  TRTTG  阅读(261)  评论(0编辑  收藏  举报