POJ3321 线段树,树状数组 建树经典题

这道题的意思是求一个结点的子树和(包含这个结点),操作只有两个,查询,和单点修改。容易想到用线段树来维护,怎么构建线段树是个问题。这道题正好学习了一下,dfs来遍历一遍,那么每一颗子树对应的新的结点的值都是连续的,我们遍历返回这个子树的最大值最小值,也就是要查找的范围。这道题在poj上提交有点小问题,用习惯了vector,居然被t掉了。。。换成了手写的邻接表就能过。。。下面上代码,一开始用的线段树写的,后来换成了树状数组(毕竟只需要单点修改):

1:线段树版

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <queue>
#include <vector>
#include <map>
#include <cstdlib>
#include <cstring>
#define ll long long
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
#define MAXN 100005
#define INF 1<<31
using namespace std;

vector <int> G[MAXN];
bool vis[MAXN];
int Sum,cnt,n,m,e[MAXN],q[MAXN],head[MAXN];

struct node{
    int val,next;
}edge[MAXN<<1];

struct Tree{
    int l,r;
    int sum;
}tree[MAXN<<2];

void add_edge(int u,int v){
    edge[cnt].val=v;
    edge[cnt].next=head[u];
    head[u]=cnt;
    cnt++;
}

void dfs(int u){
    ++Sum;
    q[u]=Sum;
    vis[u]=true;
    for(int i=head[u];i!=-1;i=edge[i].next){
        if(vis[edge[i].val])    continue;
        dfs(edge[i].val);
    }
    e[u]=Sum;
}

void pushup(int rt){
    tree[rt].sum=tree[rt<<1].sum+tree[rt<<1|1].sum;
}

void Build(int rt,int l,int r){
    tree[rt].l=l;tree[rt].r=r;
    if(l==r){
        tree[rt].sum=1;
        return;
    }
    int mid = (l+r)>>1;
    Build(lson);
    Build(rson);
    pushup(rt);
}

void Modify(int rt,int x){
    int l=tree[rt].l,r=tree[rt].r;
    if(l==r){
        if(tree[rt].sum){
            tree[rt].sum=0;
        }
        else
            tree[rt].sum=1;
        return;
    }
    int mid=(l+r)>>1;
    if(x<=mid)
        Modify(rt<<1,x);
    else
        Modify(rt<<1|1,x);
    pushup(rt);
}

int Query(int rt,int l,int r){
    if(tree[rt].l==l && tree[rt].r==r){
        return tree[rt].sum;
    }
    int mid = (tree[rt].l+tree[rt].r)>>1;
    if(r<=mid)
        return Query(rt<<1,l,r);
    else if(l>mid)
        return Query(rt<<1|1,l,r);
    else
        return (Query(lson)+Query(rson));
}

int main()
{
    //freopen("test.in","r",stdin);
    while(scanf("%d",&n)!=EOF){
        cnt=0;
        int u,v;
        for(int i=1;i<=n;i++){
            head[i]=-1;
            vis[i]=false;
        }
        for(int i=0;i<n-1;i++){
            scanf("%d%d",&u,&v);
            add_edge(u,v);
            add_edge(v,u);
        }
        Sum = 0;
        dfs(1);
        Build(1,1,n);
        char s[5];
        int x;
        scanf("%d",&m);
        while(m--){
            scanf("%s%d",s,&x);
            if(s[0]=='C')
                Modify(1,q[x]);
            else
                printf("%d\n",Query(1,q[x],e[x]));
        }
    }
    return 0;
}

2:树状数组版:

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <queue>
#include <vector>
#include <map>
#include <cstdlib>
#include <cstring>
#define ll long long
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
#define MAXN 100005
#define INF 1<<31
using namespace std;

struct node{
    int val,next;
}edge[MAXN<<1];

int Sum,n,m,cnt,t[MAXN],s[MAXN],c[MAXN],a[MAXN],head[MAXN];
bool vis[MAXN];

void add_edge(int u,int v){
    edge[cnt].val=v;
    edge[cnt].next=head[u];
    head[u]=cnt;
    cnt++;
}

void dfs(int u){
    ++Sum;
    s[u]=Sum;
    vis[u]=true;
    for(int i=head[u];i!=-1;i=edge[i].next){
        if(vis[edge[i].val])    continue;
        dfs(edge[i].val);
    }
    t[u]=Sum;
}

int lowbit(int x){
    return x&(-x);
}

void update(int x){
    int val=1;
    if(a[x]){
        val = -1;
        a[x]=0;
    }
    else a[x]=1;
    while(x<=n){
        c[x]+=val;
        x+=lowbit(x);
    }
}

int getsum(int x){
    int sum = 0;
    while(x>0){
        sum+=c[x];
        x-=lowbit(x);
    }
    return sum;
}
int main()
{
   // freopen("test.in","r",stdin);
    while(scanf("%d",&n)!=EOF){
        int u,v;
        cnt=0;
        for(int i=1;i<=n;i++){
            a[i]=1;
            c[i]=lowbit(i);
            head[i]=-1;
            vis[i]=false;
        }
        for(int i=0;i<n-1;i++){
            scanf("%d%d",&u,&v);
            add_edge(u,v);
            add_edge(v,u);
        }
        Sum = 0;
        dfs(1);
        char str[5];
        int x;
        scanf("%d",&m);
        while(m--){
            scanf("%s%d",str,&x);
            if(str[0]=='C')
                update(s[x]);
            else
                printf("%d\n",(getsum(t[x])-getsum(s[x]-1)));
        }
    }
    return 0;
}



posted @ 2015-04-01 23:54  hqwhqwhq  阅读(144)  评论(0编辑  收藏  举报