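// bzoj1036: [ZJOI2008] 树的统计 Count — 树链剖分 + 线段树
// (tree statistics: heavy-light decomposition + segment tree)
//
// 入门题 + 熟悉代码 (an introductory problem, solved to get familiar with writing this code)
//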

/**************************************************************
    Problem: 1036
    User: 96655
    Language: C++
    Result: Accepted
    Time:2472 ms
    Memory:5536 kb
****************************************************************/
 
#include<climits>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<string>
#include<utility>
#include<vector>
using namespace std;
// Capacity of every per-node array; main numbers the nodes 1..n.
const int maxn = 30000 + 5;

// One directed half of an undirected tree edge, kept in a static linked adjacency list.
struct Edge
{
    int v;     // node this half-edge points to
    int next;  // index of the next half-edge leaving the same node, or -1 at the end
};
Edge edge[maxn * 2];   // main stores each undirected edge twice (u->v and v->u)

int va[maxn];    // va[u]: input weight of node u
int p;           // number of half-edges stored so far (next free slot in edge[])
int head[maxn];  // head[u]: first half-edge leaving u, or -1 when there is none
int clk;         // running DFS counter used by dfs2 to hand out positions
int n;           // number of nodes in the current test case
void init()
{
    memset(head,-1,sizeof(head));
    p=clk=0;
}
void addedge(int u,int v)
{
    edge[p].v=v;
    edge[p].next=head[u];
    head[u]=p++;
}
// Heavy-light decomposition bookkeeping, indexed by node unless noted otherwise.
int id[maxn];   // id[u]: 1-based position of u in the order produced by dfs2
int sz[maxn];   // sz[u]: number of nodes in u's subtree (filled by dfs1)
int dep[maxn];  // dep[u]: depth of u; main roots the tree at node 1 with depth 0
int fa[maxn];   // fa[u]: parent of u; main passes the root as its own parent
int son[maxn];  // son[u]: child with the largest subtree, or -1 when u has no child
int top[maxn];  // top[u]: first node of the heavy chain that contains u
int xx[maxn];   // xx[pos]: weight of the node placed at position pos; read by build()
// First decomposition pass over the subtree of u (entered from parent f at depth d):
// records depth, parent and subtree size, and picks the heavy child -- the child
// with the strictly largest subtree, the earliest one in adjacency order on ties.
void dfs1(int u,int f,int d)
{
    dep[u] = d;
    fa[u] = f;
    sz[u] = 1;
    son[u] = -1;
    int heaviest = 0;  // subtree size of the heavy child chosen so far (0 = none yet)
    for (int e = head[u]; e != -1; e = edge[e].next)
    {
        const int to = edge[e].v;
        if (to != f)
        {
            dfs1(to, u, d + 1);
            sz[u] += sz[to];
            if (sz[to] > heaviest)
            {
                heaviest = sz[to];
                son[u] = to;
            }
        }
    }
}
// Second decomposition pass: visits the heavy child first so every heavy chain gets
// a contiguous run of positions, copies each node's weight into xx[] at its position,
// and records tp as the top of the chain u belongs to.
void dfs2(int u,int tp)
{
    top[u] = tp;
    id[u] = ++clk;
    xx[clk] = va[u];
    const int heavy = son[u];
    if (heavy != -1)
        dfs2(heavy, tp);  // the heavy child continues the current chain
    for (int e = head[u]; e != -1; e = edge[e].next)
    {
        const int to = edge[e].v;
        if (to != fa[u] && to != heavy)
            dfs2(to, to);  // every light child starts a chain of its own
    }
}
// Segment tree over positions 1..n; node rt has children rt*2 and rt*2+1.
int sum[maxn * 4];   // sum[rt]: total weight over rt's range
int maxv[maxn * 4];  // maxv[rt]: largest weight in rt's range
// Recompute node rt's sum and maximum from its two children.
void pushup(int rt)
{
    const int lc = rt << 1;
    const int rc = lc | 1;
    sum[rt] = sum[lc] + sum[rc];
    maxv[rt] = maxv[lc] > maxv[rc] ? maxv[lc] : maxv[rc];
}
// Build segment-tree node rt, covering positions [l, r], from the weights in xx[].
void build(int rt,int l,int r)
{
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        build(rt << 1, l, mid);
        build(rt << 1 | 1, mid + 1, r);
        pushup(rt);
    }
    else
    {
        // Leaf: one position, so sum and maximum are both its weight.
        maxv[rt] = xx[l];
        sum[rt] = xx[l];
    }
}
// Point assignment: set position pos to value c below node rt (range [l, r]),
// then refresh the aggregates on the way back up.
void update(int rt,int l,int r,int pos,int c)
{
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        if (pos > mid)
            update(rt << 1 | 1, mid + 1, r, pos, c);
        else
            update(rt << 1, l, mid, pos, c);
        pushup(rt);
        return;
    }
    sum[rt] = c;
    maxv[rt] = c;
}
// Sum of the weights at positions in [x, y] ∩ [l, r], where [l, r] is node rt's range.
int query1(int rt,int l,int r,int x,int y)
{
    if (l >= x && r <= y)
        return sum[rt];  // node range lies entirely inside the query
    const int mid = (l + r) >> 1;
    const int fromLeft = (x <= mid) ? query1(rt << 1, l, mid, x, y) : 0;
    const int fromRight = (y > mid) ? query1(rt << 1 | 1, mid + 1, r, x, y) : 0;
    return fromLeft + fromRight;
}
// Maximum weight at positions in [x, y] ∩ [l, r], where [l, r] is node rt's range.
// Precondition (as in the original): [x, y] overlaps [l, r]. Every recursive call
// below keeps that true, so no "minus infinity" sentinel is needed. The former
// hard-coded -99999 was returned instead of the real maximum whenever every weight
// in the range was below -99999.
int query2(int rt,int l,int r,int x,int y)
{
    if(x<=l&&r<=y)
        return maxv[rt];
    int m=(l+r)>>1;
    if(y<=m) return query2(rt*2,l,m,x,y);        // query lies in the left half only
    if(x>m)  return query2(rt*2+1,m+1,r,x,y);    // query lies in the right half only
    return std::max(query2(rt*2,l,m,x,y),query2(rt*2+1,m+1,r,x,y));
}
// Sum of the node weights on the tree path from u to v, both endpoints included.
int getsum(int u,int v)
{
    int total = 0;
    // Lift the endpoint whose chain top is deeper, one whole chain segment at a time,
    // until both endpoints sit on the same heavy chain.
    for (; top[u] != top[v]; u = fa[top[u]])
    {
        if (dep[top[v]] > dep[top[u]])
            std::swap(u, v);
        total += query1(1, 1, n, id[top[u]], id[u]);
    }
    // On a common chain the shallower node has the smaller position.
    const int lo = dep[u] <= dep[v] ? u : v;
    const int hi = dep[u] <= dep[v] ? v : u;
    total += query1(1, 1, n, id[lo], id[hi]);
    return total;
}
// Maximum node weight on the tree path from u to v, both endpoints included.
// The running maximum starts at INT_MIN rather than a hand-picked -99999, so the
// result is correct for any int weight, however negative.
int getmax(int u,int v)
{
    int ans=INT_MIN;
    while(top[u]!=top[v])
    {
        // Always lift the endpoint whose chain top is deeper, so the two endpoints
        // end up on a common chain.
        if(dep[top[u]]<dep[top[v]])
            std::swap(u,v);
        ans=std::max(ans,query2(1,1,n,id[top[u]],id[u]));  // segment top[u]..u
        u=fa[top[u]];
    }
    // Same chain now; the shallower node has the smaller position.
    if(dep[u]>dep[v])std::swap(u,v);
    ans=std::max(ans,query2(1,1,n,id[u],id[v]));
    return ans;
}
int main()
{
    while(~scanf("%d",&n))
    {
        init();
        for(int i=1; i<n; ++i)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            addedge(u,v);
            addedge(v,u);
        }
        for(int i=1; i<=n; i++)
            scanf("%d",&va[i]);
        dfs1(1,1,0);
        dfs2(1,1);
        build(1,1,n);
        int q;
        scanf("%d",&q);
        while(q--)
        {
            char s[20];
            int x,y;
            scanf("%s%d%d",s,&x,&y);
            if(s[0]=='Q')
            {
                if(s[1]=='M') printf("%d\n",getmax(x,y));
                else printf("%d\n",getsum(x,y));
            }
            else update(1,1,n,id[x],y);
        }
    }
    return 0;
}
// View Code
//
// posted @ 2015-11-02 21:43  shuguangzw  阅读(146)  评论(0)  编辑  收藏  举报