【BZOJ 1036】 树的统计count

题目

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

分析

树链剖分

代码

 

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<vector>
using namespace std;
#define MN 30000
#define fINF -30000000
int fa[MN+5],W[MN+5],ans[MN+5],dfn[MN+5],sons[MN+5],fl[MN+5],head[MN+5],ccnt=0,cnt=0;
struct data{int to,next;}e[MN*2+10];
void ins(int u,int v){
    e[++ccnt].to=v;e[ccnt].next=head[u];head[u]=ccnt;
    e[++ccnt].to=u;e[ccnt].next=head[v];head[v]=ccnt;
}
struct TREE{int val,max;}t[MN*4+10];
int n,q;
void update(int k,int l,int r,int q,int v){
    if(l==r) {t[k].val=t[k].max=v;return;}
    int mid=(l+r)/2;
    if(q<=mid) update(k<<1,l,mid,q,v);
    if(q>mid) update(k<<1|1,mid+1,r,q,v);
    t[k].val=t[k<<1].val+t[k<<1|1].val;
    t[k].max=max(t[k<<1].max,t[k<<1|1].max);
}
void dfs1(int x){
    sons[x]=1;
    for(int i=head[x];i;i=e[i].next){
        if(e[i].to==fa[x]) continue;
        fl[e[i].to]=fl[x]+1; fa[e[i].to]=x;
        dfs1(e[i].to);
        sons[x]+=sons[e[i].to];
    }
}
void dfs2(int x,int chain){
    int k=0;
    dfn[x]=++cnt;
    ans[x]=chain;
    for(int i=head[x];i;i=e[i].next)
        if(fl[e[i].to]>fl[x]&&sons[e[i].to]>sons[k]) k=e[i].to;
    if(k==0) return;
    dfs2(k,chain);
    for(int i=head[x];i;i=e[i].next)
        if(fl[e[i].to]>fl[x]&&k!=e[i].to)
            dfs2(e[i].to,e[i].to);
}
int query_max(int k,int l,int r,int a,int b){
    if(a<=l&&r<=b) return t[k].max;
    int m=(l+r)/2,anss=fINF;
    if(a<=m) anss=max(anss,query_max(k<<1,l,m,a,b));
    if(m<b) anss=max(anss,query_max(k<<1|1,m+1,r,a,b));
    return anss;
}
int query_sum(int k,int l,int r,int a,int b){
    if(a<=l&&r<=b) return t[k].val;
    int m=(l+r)/2,anss=0;
    if(a<=m) anss+=query_sum(k<<1,l,m,a,b);
    if(m<b) anss+=query_sum(k<<1|1,m+1,r,a,b);
    return anss;
}
int find_sum(int x,int y){
    int sum=0;
    while(ans[x]!=ans[y]){
        if(fl[ans[x]]<fl[ans[y]]) swap(x,y);
        sum+=query_sum(1,1,n,dfn[ans[x]],dfn[x]);
        x=fa[ans[x]];
    }
    if(dfn[x]>dfn[y]) swap(x,y);
    sum+=query_sum(1,1,n,dfn[x],dfn[y]);
    return sum;
}
int find_max(int x,int y){
    int mx=fINF;
    while(ans[x]!=ans[y]){
        if(fl[ans[x]]<fl[ans[y]]) swap(x,y);
        mx=max(mx,query_max(1,1,n,dfn[ans[x]],dfn[x]));
        x=fa[ans[x]];
    }
    if(dfn[x]>dfn[y]) swap(x,y);
    mx=max(mx,query_max(1,1,n,dfn[x],dfn[y]));
    return mx;
}
void solve(int k,int a,int b){
    if(k==1) printf("%d\n",find_max(a,b));
    if(k==2) printf("%d\n",find_sum(a,b));
    if(k==3) update(1,1,n,dfn[a],b);
}
int main(){
    int u,v,ro;
    scanf("%d",&n);
    for(int i=1;i<n;i++) scanf("%d%d",&u,&v),ins(u,v);
    dfs1(1); dfs2(1,1);
    for(int i=1;i<=n;i++) scanf("%d",&W[i]),update(1,1,n,dfn[i],W[i]);
    scanf("%d",&q);
    while(q--){
        char ch=getchar(); int k;
        int x1=0,f1=1,x2=0,f2=1;
        while(ch<'0'||ch>'9'){
            if(ch=='X') k=1; if(ch=='U') k=2;
            if(ch=='H') k=3; if(ch=='-') f1=-1;
            ch=getchar();
        }
        while(ch>='0'&&ch<='9') x1=x1*10+ch-'0',ch=getchar();
        while(ch<'0'||ch>'9') f2=ch=='-'?-1:1,ch=getchar();
        while(ch>='0'&&ch<='9') x2=x2*10+ch-'0',ch=getchar();
        solve(k,x1*f1,x2*f2);
    }
    return 0;
}

————————————————————————————————————

来自PaperCloud的博客,未经允许,请勿转载,谢谢。

posted @ 2017-07-08 16:20  PaperCloud  阅读(219)  评论(0编辑  收藏  举报