【ZJOI2008】树的统计

Description

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身


树链剖分模板题,这里写的很好。

简单补充一下,就是把一棵树剖成一条链,然后链上做线段树即可。

Code

#include<iostream>
#include<cstdio>
#include<cstdlib>
#define fo(i,j,k) for(int i=j;i<=k;i++)
#define fd(i,j,k) for(int i=j;i>=k;i--)
#define N 90010
#define inf 1000000000
using namespace std;
int w[N];
int to[N],next[N],last[N],num=0;
int dep[N],siz[N];
int top[N],son[N];
int fa[N];
int pos[N],z[N];
struct node{
    int s,mx;
}tr[N*4];
void link(int x,int y)
{
    num++;
    to[num]=y;
    next[num]=last[x];
    last[x]=num;
}
void find(int x)//寻找重边
{
    siz[x]=1;son[x]=0;
    for(int i=last[x];i;i=next[i])
    {
        int v=to[i];
        if(v!=fa[x])
        {
            dep[v]=dep[x]+1;
            fa[v]=x;
            find(v);
            if(siz[v]>siz[son[x]]) son[x]=v;
            siz[x]+=siz[v];
        }
    }
}
int cnt=0;
void dfs(int x,int t)//剖链的过程
{
    pos[x]=++cnt;
    top[x]=t;
    z[pos[x]]=x;
    if(son[x]) dfs(son[x],t);
    for(int i=last[x];i;i=next[i])
    {
        int v=to[i];
        if(v!=fa[x] && v!=son[x])
        dfs(v,v);
    }
}
void build(int v,int l,int r)
{
    if(l==r)
    {
        tr[v].s=tr[v].mx=w[z[l]];
        return;
    }
    int mid=(l+r)/2;
    build(v*2,l,mid);
    build(v*2+1,mid+1,r);
    tr[v].s=tr[v*2].s+tr[v*2+1].s;
    tr[v].mx=max(tr[v*2].mx,tr[v*2+1].mx);
}
void change(int v,int l,int r,int x,int p)
{
    if(l==r && l==x)
    {
        tr[v].s=tr[v].mx=p;
        return;
    }
    int mid=(l+r)/2;
    if(x<=mid) change(v*2,l,mid,x,p);
    else change(v*2+1,mid+1,r,x,p);
    tr[v].s=tr[v*2].s+tr[v*2+1].s;
    tr[v].mx=max(tr[v*2].mx,tr[v*2+1].mx);
}
int qsum(int v,int l,int r,int x,int y)
{
    if(l==x && r==y) return tr[v].s;
    int mid=(l+r)/2;
    if(y<=mid) return qsum(v*2,l,mid,x,y);
    else if(x>mid) return qsum(v*2+1,mid+1,r,x,y);
    else return qsum(v*2,l,mid,x,mid)+qsum(v*2+1,mid+1,r,mid+1,y);
}
int qmax(int v,int l,int r,int x,int y)
{
    if(l==x && r==y) return tr[v].mx;
    int mid=(l+r)/2;
    if(y<=mid) return qmax(v*2,l,mid,x,y);
    else if(x>mid) return qmax(v*2+1,mid+1,r,x,y);
    else return max(qmax(v*2,l,mid,x,mid),qmax(v*2+1,mid+1,r,mid+1,y));
}
int findmax(int u,int v)
{
    int f1=top[u],f2=top[v];
    int tmp=-inf;
    while(f1!=f2)
    {
        if(dep[f1]<dep[f2])
        {
            swap(f1,f2);
            swap(u,v);
        }
        tmp=max(tmp,qmax(1,1,cnt,pos[f1],pos[u]));
        u=fa[f1];
        f1=top[u];
    }
    if(dep[u]>dep[v]) swap(u,v); 
    return max(tmp,qmax(1,1,cnt,pos[u],pos[v]));
}
int findsum(int u,int v) 
{
    int f1=top[u],f2=top[v];
    int tmp=0;
    while(f1!=f2)
    {
        if(dep[f1]<dep[f2])
        {
            swap(f1,f2);
            swap(u,v);
        }
        tmp+=qsum(1,1,cnt,pos[f1],pos[u]);
        u=fa[f1];
        f1=top[u];
    }
    if(dep[u]>dep[v]) swap(u,v);
    return tmp+qsum(1,1,cnt,pos[u],pos[v]);
}
int main()
{
    int n;
    cin>>n;
    fo(i,1,n-1)
    {
        int x,y;
        scanf("%d %d",&x,&y);
        link(x,y);
        link(y,x);
    }
    fo(i,1,n) scanf("%d",&w[i]);
    find(1);
    dfs(1,1);
    build(1,1,cnt);
    int Q;
    cin>>Q;
    while(Q--)
    {
        char s[11];
        scanf("%s",s);
        if(s[0]=='C')
        {
            int x,t;
            scanf("%d %d",&x,&t);
            change(1,1,cnt,pos[x],t);
        }
        else
        {
            int x,y;
            scanf("%d %d",&x,&y);
            if(s[1]=='M') printf("%d\n",findmax(x,y));
            else printf("%d\n",findsum(x,y));
        }
    }
}
posted @ 2016-05-13 11:54  sadstone  阅读(35)  评论(0编辑  收藏  举报