P2486 [SDOI2011]染色 区间合并+树链剖分(加深对线段树的理解)

#include<bits/stdc++.h>
using namespace std;
const  int M=3e5+5;
struct node{
    int l,r,cnt,lazy;
    node(int l1=0,int r1=0,int cnt1=0,int lazy1=0):l(l1),r(r1),cnt(cnt1),lazy(lazy1){}
}tree[M<<2];
int fa[M],sz[M],deep[M],dfn[M],son[M],to[M],a[M],top[M],cnt,n;
char s[2];
vector<int>g[M];
void dfs1(int u,int from){
    fa[u]=from;
    sz[u]=1;
    deep[u]=deep[from]+1;
    for(int i=0;i<g[u].size();i++){

        int v=g[u][i];
        if(v!=from){
            dfs1(v,u);
            sz[u]+=sz[v];
            if(sz[v]>sz[son[u]])
                son[u]=v;
        }
        
    }
}
void dfs2(int u,int t){
    top[u]=t;
    dfn[u]=++cnt;
    to[cnt]=u;
    if(!son[u])
        return ;
    dfs2(son[u],t);
    for(int i=0;i<g[u].size();i++){
        int v=g[u][i];
        if(v!=fa[u]&&v!=son[u])
            dfs2(v,v);
    }
}
void up(int root){
    tree[root].cnt=tree[root<<1].cnt+tree[root<<1|1].cnt;
    if(tree[root<<1].r==tree[root<<1|1].l)
        tree[root].cnt--;
    tree[root].l=tree[root<<1].l;
    tree[root].r=tree[root<<1|1].r;
}
void build(int root,int l,int r){
    tree[root].lazy=0;
    if(l==r){
        tree[root].l=tree[root].r=a[to[l]];
        tree[root].cnt=1;
        return ;
    }
    int midd=(l+r)>>1;
    build(root<<1,l,midd);
    build(root<<1|1,midd+1,r);
    up(root);
}
void pushdown(int root){
    tree[root<<1]=tree[root<<1|1]=node(tree[root].l,tree[root].r,1,tree[root].lazy);
    tree[root].lazy=0;
}
void update(int L,int R,int x,int root,int l,int r){
    if(L<=l&&r<=R){
        tree[root]=node(x,x,1,x);
        return ;
    }
    if(tree[root].lazy)
        pushdown(root);
    int midd=(l+r)>>1;
    if(L<=midd)
        update(L,R,x,root<<1,l,midd);
    if(R>midd)
        update(L,R,x,root<<1|1,midd+1,r);
    up(root);
}
void add(int u,int v ,int w){
    int fu=top[u],fv=top[v];
    while(fu!=fv){
        if(deep[fu]>=deep[fv])
            update(dfn[fu],dfn[u],w,1,1,n),u=fa[fu],fu=top[u];
        else
            update(dfn[fv],dfn[v],w,1,1,n),v=fa[fv],fv=top[v];
    }
    if(dfn[u]<=dfn[v])
        update(dfn[u],dfn[v],w,1,1,n);
    else
        update(dfn[v],dfn[u],w,1,1,n);
}
node meger(node a,node b){
    if(!a.cnt)
        return b;
    if(!b.cnt)
        return a;
    node ans=node(0,0,0,0);
    ans.cnt=a.cnt+b.cnt;
    if(a.r==b.l)
        ans.cnt--;
    ans.l=a.l;
    ans.r=b.r;
    return ans;
}
node query(int L,int R,int root,int l,int r){
    if(L<=l&&r<=R){
        return tree[root];
    }
    if(tree[root].lazy)
        pushdown(root);
    int midd=(l+r)>>1;
    node ans;
    if(L<=midd)
        ans=query(L,R,root<<1,l,midd);
    if(R>midd)
        ans=meger(ans,query(L,R,root<<1|1,midd+1,r));
    up(root);
    return ans;
}
int solve(int u,int v){
    node l,r;
    int fv=top[v],fu=top[u];
    while(fv!=fu){
        if(deep[fu]>=deep[fv])
            l=meger(query(dfn[fu],dfn[u],1,1,n),l),u=fa[fu],fu=top[u];
        else
            r=meger(query(dfn[fv],dfn[v],1,1,n),r),v=fa[fv],fv=top[v];
    }
    if(dfn[u]<=dfn[v])
        r=meger(query(dfn[u],dfn[v],1,1,n),r);
    else
        l=meger(query(dfn[v],dfn[u],1,1,n),l);
    swap(l.l,l.r);
    l=meger(l,r);
    return l.cnt;
}
int main(){
    int m;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
        scanf("%d",&a[i]);
    for(int i=1;i<n;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
    }//cout<<"!!"<<endl;
    dfs1(1,1);
    dfs2(1,1);
    
    build(1,1,n);
    while(m--){
        int u,v,w;
        scanf("%s",s);
        if(s[0]=='Q'){
            scanf("%d%d",&u,&v);
            printf("%d\n",solve(u,v));
        }
        else{
            scanf("%d%d%d",&u,&v,&w);
            add(u,v,w);
        }
    }
    return 0;
}
View Code

 

posted @ 2019-05-20 17:56  starve_to_death  阅读(127)  评论(0编辑  收藏  举报