BZOJ 1036: [ZJOI2008]树的统计Count(树链剖分)

第一次打树链剖分,也挺容易的嘛~~~,两次dfs后建线段树维护就行了~~~

CODE:

1
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
#define maxn 30010
// Adjacency lists kept as index-linked edge records (index 0 terminates a list).
struct edges{
    int to,next;    // to: endpoint of this entry; next: index of the next entry of the same vertex
}edge[maxn*2];      // sized for two directed entries per vertex slot
// head[v]: index of the most recently added entry leaving v (0 when v has none).
// Deliberately NOT called `next`: with `using namespace std;` in effect an
// unqualified global `next` is ambiguous with std::next on C++11 and later.
// l: number of entries of edge[] in use (entries are numbered from 1).
int head[maxn],l;
// Appends a directed entry from -> to at the front of from's list. Always returns 0.
int addedge(int from,int to){
    ++l;
    edge[l].to=to;
    edge[l].next=head[from];
    head[from]=l;
    return 0;
}
// dep: depth (dep[par[u]]+1), par: parent, si: subtree size,
// ch: child with the largest subtree (0 for a vertex without children).
int dep[maxn],par[maxn],si[maxn],ch[maxn];
// b[v] is true while v has not been visited by the traversal currently running;
// the caller must set it to true for all vertices before each traversal.
bool b[maxn];
// First pass: fills dep/par/si/ch for the subtree reachable from u through
// unvisited vertices. par[u] must already be set (0 for the root, which gives
// the root depth 1 because dep[0] is 0). Always returns 0.
int dfs(int u){
    dep[u]=dep[par[u]]+1;
    b[u]=0;
    si[u]=1;
    for (int i=head[u];i;i=edge[i].next){
        const int v=edge[i].to;
        if (!b[v]) continue;            // already visited: this is the parent
        par[v]=u;
        dfs(v);
        si[u]+=si[v];
        // Strict '<' keeps the first child seen among equally large subtrees.
        if (si[ch[u]]<si[v]) ch[u]=v;
    }
    return 0;
}
// top[c]: first vertex of chain c; id[v]: chain index of v;
// pos[v]: position of v in the flattened order; arr[p]: vertex stored at position p.
int top[maxn],id[maxn],arr[maxn],pos[maxn];
int ind,num;    // ind: number of chains opened so far; num: positions handed out so far
// Second pass: assigns chain indices and flattened positions.
// flag == true opens a new chain whose first vertex is x; flag == false makes x
// continue the chain that is current at the time of the call.
// Requires dfs() to have filled ch[] and b[] to be reset to true. Always returns 0.
int heavy_edge(int x,bool flag){
    if (flag) top[++ind]=x;
    id[x]=ind;
    b[x]=0;
    arr[pos[x]=++num]=x;
    // The heavy child goes first so that a chain receives consecutive positions.
    if (ch[x]) heavy_edge(ch[x],0);
    for (int i=head[x];i;i=edge[i].next){
        const int v=edge[i].to;
        if (b[v]) heavy_edge(v,1);      // every remaining child starts its own chain
    }
    return 0;
}
struct node{
    int l,r,Max,Sum;
}t[maxn*4];
int w[maxn];
int buildtree(int x,int l,int r){
    t[x].l=l;t[x].r=r;
    if (l==r) {t[x].Max=t[x].Sum=w[arr[l]];return 0;}
    buildtree(x<<1,l,(l+r)>>1);
    buildtree((x<<1)+1,((l+r)>>1)+1,r);
    t[x].Max=max(t[x<<1].Max,t[(x<<1)+1].Max);
    t[x].Sum=t[x<<1].Sum+t[(x<<1)+1].Sum;
    return 0;
}
int n;
int init(){
    scanf("%d",&n);
    for (int i=1;i<n;i++){
    int x,y;
    scanf("%d%d",&x,&y);
    addedge(x,y);addedge(y,x);
    }
    memset(b,true,sizeof(b));
    dfs(1);
    memset(b,true,sizeof(b));
    heavy_edge(1,1);
    for (int i=1;i<=n;i++) scanf("%d",w+i);
    buildtree(1,1,n);
    return 0;
}
int change(int x,int y){
    int l=t[x].l,r=t[x].r;
    if (y<l||y>r) return 0;
    if (l==r) {t[x].Max=t[x].Sum=w[arr[y]];return 0;}
    change(x<<1,y);change((x<<1)+1,y);
    t[x].Max=max(t[x<<1].Max,t[(x<<1)+1].Max);
    t[x].Sum=t[x<<1].Sum+t[(x<<1)+1].Sum;
    return 0;
}
int gmax(int x,int x1,int y1){
    int l=t[x].l,r=t[x].r;
    if (l>y1||r<x1) return -30002;
    if (l>=x1&&r<=y1) return (t[x].Max);
    return(max(gmax(x<<1,x1,y1),gmax((x<<1)+1,x1,y1)));
}
int gsum(int x,int x1,int y1){
    int l=t[x].l,r=t[x].r;
    if (l>y1||r<x1) return 0;
    if (l>=x1&&r<=y1) return (t[x].Sum);
    return(gsum(x<<1,x1,y1)+gsum((x<<1)+1,x1,y1));
}
int MAX(int l,int r){
    int ans=-30001;
    while (1){
    if (dep[top[id[l]]]<dep[top[id[r]]]) swap(l,r);
    if (id[l]!=id[r]) {ans=max(ans,gmax(1,pos[top[id[l]]],pos[l]));l=par[top[id[l]]];}
    else {if (dep[l]<dep[r]) swap(l,r);ans=max(ans,gmax(1,pos[r],pos[l]));return ans;}
    }
}
int SUM(int l,int r){
    int ans=0;
    while (1){
    if (dep[top[id[l]]]<dep[top[id[r]]]) swap(l,r);
    if (id[l]!=id[r]) {ans+=gsum(1,pos[top[id[l]]],pos[l]);l=par[top[id[l]]];}
    else {if (dep[l]<dep[r]) swap(l,r);ans+=gsum(1,pos[r],pos[l]);return ans;}
    }
}
int m;
int work(){
    scanf("%d",&m);
    for (int i=1;i<=m;i++) {
    int x,y;char s[10];
    scanf("%s%d%d",s,&x,&y);
    if (s[0]=='C') {w[x]=y;change(1,pos[x]);}
    if (s[1]=='M') {printf("%d\n",MAX(x,y));}
    if (s[1]=='S') {printf("%d\n",SUM(x,y));}
    }
    return 0;
}
int main(){
    init();
    work();
    return 0;
}

 

posted @ 2014-06-22 10:25  New_Godess  阅读(135)  评论(0)  编辑  收藏  举报