[ZJOI2008]树的统计

题目描述

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。

我们将以下面的形式来要求你对这棵树完成一些操作:

I. CHANGE u t : 把结点u的权值改为t

II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值

III. QSUM u v: 询问从点u到点v的路径上的节点的权值和

注意:从点u到点v的路径上的节点包括u和v本身

输入输出格式

**输入格式:

** 输入文件的第一行为一个整数n,表示节点的个数。

接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。

接下来一行n个整数,第i个整数wi表示节点i的权值。

接下来1行,为一个整数q,表示操作的总数。

接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。

输出格式:

对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

输入输出样例

输入样例#1:

4 1 2 2 3 4 1 4 2 1 3

12 QMAX 3 4

QMAX 3 3

QMAX 3 2

QMAX 2 3

QSUM 3 4

QSUM 2 1

CHANGE 1 5

QMAX 3 4

CHANGE 3 6

QMAX 3 4

QMAX 2 4

QSUM 3 4

输出样例#1:

4 1 2 2 10 6 5 6 5 16

solution

最喜欢一些难度虚高的题,比如说这一道。

这道题就是最简单的树链剖分了,基本不需要什么其它的技巧了

然而我的代码依旧在BZOJ上RE。

题目要求查询最大值和区间和,活生生的套路,只要会基本的树链剖分这道题就可以A掉

值得注意的是最大值初始赋为负数,因为权值可以为负(不然只有十分)

代码

#include<bits/stdc++.h>

using namespace std;

const int MAXN = 30000 + 10;

const int inf = 1<<30;

inline int read()
{
    char ch;
    int fl=1;
    int xx=0;
    do{
      ch= getchar();
      if(ch=='-')
        fl=-1;
    }while(ch<'0'||ch>'9');
    do{
        xx=(xx<<3)+(xx<<1)+ch-'0';
        ch=getchar();
    }while(ch>='0'&&ch<='9');
    return xx*fl;
}

inline int Max(int a,int b)
{
    if(a>=b) return a;
    else return b;
}

int n,q; 

string op;

int w[MAXN];

struct node
{
    int to;
    int next;
};node g[MAXN*2];

int cnt=0,head[MAXN];

inline void addedge(int a,int b)
{
    ++cnt;g[cnt].to=b;g[cnt].next=head[a];head[a]=cnt;return;
}

int fa[MAXN],dep[MAXN],son[MAXN],tot[MAXN];

#define v g[i].to

inline void dfs(int now)
{
    son[now]=0,tot[now]=1;
    for(int i=head[now];i>0;i=g[i].next){
        if(v!=fa[now]){
            fa[v]=now;
            dep[v]=dep[now]+1;
            dfs(v);
            if(tot[son[now]]<tot[v]) son[now]=v;
            tot[now]+=tot[v];
        }
    }
}

int num[MAXN],top[MAXN],h=0;

inline void dfs2(int now,int t)
{
    top[now]=t;++h;num[now]=h;
    if(son[now]!=0) dfs2(son[now],t);
    for(int i=head[now];i>0;i=g[i].next)
    {
        if(v!=fa[now]&&v!=son[now])
            dfs2(v,v);
    }
}

struct s_tree
{
    int l;
    int r;
    int sum;
    long long maxn;
    inline int mid()
    {
        return (l+r)>>1;
    }
};s_tree tree[MAXN*4];

#define lc o<<1
#define rc o<<1|1

inline void build(int o,int l,int r)
{
    tree[o].l=l;
    tree[o].r=r;
    if(l==r)
    {
        tree[o].maxn=-inf;
        tree[o].sum=0;
    }
    else
    {
        int mid=tree[o].mid();
        build(lc,l,mid);
        build(rc,mid+1,r);
    }
}

inline void change(int o,int x,int y)
{
    if(tree[o].l==tree[o].r){
        tree[o].maxn=tree[o].sum=y;
        return;
    }
    else
    {
        int mid=tree[o].mid();
        if(x<=mid) change(lc,x,y);
        else change(rc,x,y);
        tree[o].maxn=Max(tree[lc].maxn,tree[rc].maxn);
        tree[o].sum=tree[lc].sum+tree[rc].sum;
    }
}

inline long long findmax(int o,int x,int y)
{
    int l=tree[o].l;
    int r=tree[o].r;
    if(x==l&&y==r) return tree[o].maxn;
    else
    {
        int mid=tree[o].mid();
        if(x>mid) return findmax(rc,x,y);
        else if(y<=mid) return findmax(lc,x,y);
        else return Max(findmax(lc,x,mid),findmax(rc,mid+1,y));
    }
}

inline long long findsum(int o,int x,int y)
{
    int l=tree[o].l;
    int r=tree[o].r;
    if(x==l&&y==r) return tree[o].sum;
    else
    {
        int mid=tree[o].mid();
        if(x>mid) return findsum(rc,x,y);
        else if(y<=mid) return findsum(lc,x,y);
        else return (findsum(lc,x,mid)+findsum(rc,mid+1,y));
    }
}

inline long long qmax(int x,int y)
{
    int tx=top[x],ty=top[y],ans=-inf;
    while(tx!=ty){
        if(dep[tx]>dep[ty]) 
        {
            swap(x,y);
            swap(tx,ty);
        }
        ans=Max(ans,findmax(1,num[ty],num[y]));
        y=fa[ty];ty=top[y];
    }
    if(x==y){
        return Max(ans,findmax(1,num[x],num[x]));
    }
    else{
        if(dep[x]>dep[y]) swap(x,y);
        return Max(ans,findmax(1,num[x],num[y]));
    }
}

inline long long qsum(int x,int y)
{
    int tx=top[x],ty=top[y];
    long long ans=0;
    while(tx!=ty){
        if(dep[tx]>dep[ty]) 
        {
            swap(x,y);
            swap(tx,ty);
        }
        ans=ans+findsum(1,num[ty],num[y]);
        y=fa[ty];ty=top[y];
    }
    if(x==y){
        return ans+findsum(1,num[x],num[x]);
    }
    else{
        if(dep[x]>dep[y]) swap(x,y);
        return ans+findsum(1,num[x],num[y]);
    }
}

int main()
{
    n=read();
    for(int i=1;i<=n;i++)
    {
        head[i]=-1;
    }
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        addedge(x,y);
        addedge(y,x);
    }
    fa[1]=0;dep[1]=1;
    dfs(1);
    dfs2(1,1);
    build(1,1,h);
    for(int i=1;i<=n;i++)
    {
        w[i]=read();
        change(1,num[i],w[i]);
    }
    q=read();
    for(int i=1;i<=q;i++)
    {
        cin>>op;
        int x=read(),y=read();
        if(op=="QMAX")
            printf("%lld\n",qmax(x,y));
        else if(op=="QSUM")
            printf("%lld\n",qsum(x,y));
        else
            change(1,num[x],y);
    }
}

 

posted @ 2018-04-17 19:21  wlzs1432  阅读(170)  评论(0编辑  收藏  举报