BZOJ 2243 染色(树链剖分&线段树)

题目链接

题意

给定一棵有n个节点的无根树和m个操作,操作有2类:
  • 将节点a到节点b路径上所有点都染成颜色c;
  • 询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段), 如“112221”由3段组成:“11”、“222”和“1”。 请你写一个程序依次完成这m个操作。

题解

  • 先考虑将操作放到序列上,即对一段连续序列涂色,并询问区间颜色段数。这个问题用线段树即可很简单地解决:每个节点维护区间左端颜色、右端颜色以及颜色段数;合并左右两个子区间时段数相加,若左子区间的右端颜色与右子区间的左端颜色相同,则段数再减一。
  • 现在将操作放到树上,那么我们只需要对这棵树进行树链剖分,即可把一条路径拆成若干段在 DFS 序上连续的区间,从而转换成对序列的操作。
查看代码
#include <bits/stdc++.h>
using namespace std;
// Inclusive counting loop: iterates i over a, a+1, ..., b.
#define _for(i,a,b) for(int i = (a);i <= (b);++i)
typedef long long ll;
const int maxn = 1e5+5;   // capacity of every per-node array in this file
const int mod = 1e9+7;    // modulus used by qpow

// Modular exponentiation by repeated squaring: returns a^b modulo `mod`.
//
// a  base; reduced modulo `mod` up front so that every product below stays
//    under mod*mod and therefore fits in a signed 64-bit integer. The sign
//    convention of C++ `%` is kept, so a negative base can still produce a
//    negative result exactly as before.
// b  exponent; a non-positive exponent yields 1 (the loop body never runs),
//    which also prevents the endless loop a negative b would otherwise cause,
//    since an arithmetic right shift of a negative value never reaches 0.
ll qpow(ll a,ll b)
{
    a %= mod;
    ll res = 1;
    for(;b > 0;b >>= 1){
        if(b&1)res = res*a%mod;
        a = a*a%mod;
    }
    return res;
}
// Adjacency list stored in flat arrays (edges of one vertex form a linked list).
//
//   head[v]  index of the most recently added edge leaving v, or -1 if none
//   nxt[e]   index of the next edge leaving the same vertex, or -1 at the end
//   to[e]    target vertex of edge e
//   w[e]     weight of edge e (0 when push() is called without a weight)
//   sz       number of edges stored so far; also the index of the next free slot
//
// The edge arrays hold maxn<<1 entries. push() performs no bounds checking,
// so callers must keep vertex ids below maxn and the edge count below maxn<<1.
struct graph
{
    int head[maxn],nxt[maxn<<1],to[maxn<<1],w[maxn<<1],sz;
    // Empties the graph. Resetting sz together with head matters: otherwise a
    // re-initialised graph would keep appending after the stale edges, and an
    // instance without static storage would start from an indeterminate count.
    void init(){
        memset(head,-1,sizeof(head));
        sz = 0;
    }
    graph(){init();}
    // Adds the directed edge a -> b with weight c at the front of a's list.
    void push(int a,int b,int c=0){
        nxt[sz] = head[a];
        to[sz] = b;
        w[sz] = c;
        head[a] = sz++;
    }
    // g[e] is shorthand for the target vertex of edge e.
    int& operator[](const int a){return to[a];}
}g;

// State filled in by dfs1/dfs2 (indexed by node id unless noted otherwise):
//   size[v]  number of nodes in the subtree of v
//   son[v]   child of v with the largest subtree; stays 0 when v has no child
//   top[v]   first (shallowest) node of the chain that contains v
//   dfn[i]   node placed at position i of the dfs2 visiting order
//   rnk[v]   position of node v in that order (inverse of dfn)
//   dep[v]   depth of v, computed as dep[parent]+1
//   fa[v]    parent of v (the `pre` argument dfs1 was called with)
//   tot      number of positions handed out so far by dfs2
// NOTE(review): with `using namespace std;` the unqualified name `size` can be
// ambiguous with std::size when compiling as C++17 or later.
int size[maxn],son[maxn],top[maxn],dfn[maxn],rnk[maxn],dep[maxn],fa[maxn],tot;

void dfs1(int now,int pre)
{
    dep[now] = dep[pre]+1;
    fa[now] = pre;
    size[now] = 1;
    for(int i = g.head[now];~i;i = g.nxt[i]){
        if(g[i]==pre)continue;
        dfs1(g[i],now);
        size[now] += size[g[i]];
        if(size[g[i]]>size[son[now]])son[now] = g[i];
    }

}

// Second pass of the decomposition: hands out consecutive positions so that
// every chain occupies a contiguous range, and records each node's chain head.
//
// now  node being visited
// tp   head of the chain that `now` belongs to
//
// The heavy child is visited first and stays on the same chain; every other
// child starts a new chain of which it is the head.
void dfs2(int now,int tp)
{
    top[now] = tp;
    ++tot;
    dfn[tot] = now;
    rnk[now] = tot;

    const int heavy = son[now];
    if(heavy != 0){
        dfs2(heavy,tp);
    }

    int edge = g.head[now];
    while(edge != -1){
        const int child = g[edge];
        if(child != fa[now] && child != heavy){
            dfs2(child,child);
        }
        edge = g.nxt[edge];
    }
}
// Summary of a run of consecutive positions:
//   lc   colour of the first position
//   rc   colour of the last position
//   sum  number of maximal blocks of equal colour; 0 denotes the empty run
struct node
{
    int lc,rc,sum;
    // A default-constructed node is the empty run. All members are zeroed so
    // the object can be copied and inspected without reading indeterminate
    // values.
    node():lc(0),rc(0),sum(0){}
    node(int _l,int _r,int _sum):lc(_l),rc(_r),sum(_sum){}
};
// color[v]: initial colour of node v as read from the input; consumed by
// Segment_tree::build through the dfn[] order.
int color[maxn];
struct Segment_tree
{
    node tree[maxn<<2];
    int lazy[maxn<<2];
    void build(int root,int l,int r){
        if(l==r){
            tree[root].lc = color[dfn[l]];
            tree[root].rc = color[dfn[r]];
            tree[root].sum = 1;
            return;
        }
        int mid = l+r>>1;
        build(root<<1,l,mid);
        build(root<<1|1,mid+1,r);
        tree[root] = pushup(tree[root<<1],tree[root<<1|1]);
    }
    void pushdown(int root){
        if(lazy[root]){
            tree[root<<1].lc = lazy[root];
            tree[root<<1].rc = lazy[root];
            tree[root<<1].sum = 1;
            tree[root<<1|1].lc = lazy[root];
            tree[root<<1|1].rc = lazy[root];
            tree[root<<1|1].sum = 1;
            lazy[root<<1] = lazy[root<<1|1] = lazy[root];
            lazy[root] = 0;
        }
    }
    node pushup(node a,node b){
        if(a.sum==0)return b;
        if(b.sum==0)return a;
        node tmp;
        tmp.lc = a.lc;
        tmp.rc = b.rc;
        tmp.sum = a.sum+b.sum;
        if(a.rc==b.lc)tmp.sum--;
        return tmp;
    }
    void modify(int root,int l,int r,int ml,int mr,int col){

        if(l >= ml&&r <= mr){
            tree[root].lc = col;
            tree[root].rc = col;
            lazy[root] = col;
            tree[root].sum = 1;
            return;
        }
        pushdown(root);
        int mid = l+r>>1;
        if(mid>=ml)modify(root<<1,l,mid,ml,mr,col);
        if(mr>mid)modify(root<<1|1,mid+1,r,ml,mr,col);
        tree[root] = pushup(tree[root<<1],tree[root<<1|1]);
    }
    node query(int root,int l,int r,int ql,int qr){       
        if(l >= ql&&r <= qr)return tree[root];
        pushdown(root);
        int mid = l+r>>1;
        node tmp;
        if(mid<ql){
            tmp = query(root<<1|1,mid+1,r,ql,qr);
        }
        else if(qr<=mid)tmp = query(root<<1,l,mid,ql,qr);
        else tmp = pushup(query(root<<1,l,mid,ql,qr),query(root<<1|1,mid+1,r,ql,qr));
        return tmp;
    }
}sg;
// Number of tree nodes, read in main; also the size of the position range
// [1, n] passed to every segment-tree call.
int n;
void query(int l,int r)
{
    node tmp1,tmp2;
    tmp1.sum = 0;
    tmp2.sum = 0;
    int x = l,y = r;
    while(top[x]!=top[y]){       
        if(dep[top[x]]>dep[top[y]]){
            node a = sg.query(1,1,n,rnk[top[x]],rnk[x]);
            tmp1 = sg.pushup(a,tmp1);
            x = fa[top[x]];
        }
        else{
            node b = sg.query(1,1,n,rnk[top[y]],rnk[y]);
            tmp2 = sg.pushup(b,tmp2);
            y = fa[top[y]];
        }
    }
    if(dep[x]>dep[y]){
        node a = sg.query(1,1,n,rnk[y],rnk[x]);
        tmp1 = sg.pushup(a,tmp1);
    }
    else {
        node b = sg.query(1,1,n,rnk[x],rnk[y]);
        tmp2 = sg.pushup(b,tmp2);
    }
    int ans = tmp1.sum+tmp2.sum;
    if(tmp1.sum==0||tmp2.sum==0);
    else if(tmp1.lc==tmp2.lc)ans--;
    printf("%d\n",ans);
}

// Paints every node on the tree path between l and r with colour c.
//
// The path is split chain by chain; each chain piece is a contiguous range of
// positions and is handed to the segment tree.
void modify(int l,int r,int c)
{
    int u = l,v = r;
    while(true){
        if(top[u]==top[v])break;
        // Keep u as the endpoint whose chain head is deeper, then lift it.
        if(dep[top[u]]<dep[top[v]]){
            const int t = u;
            u = v;
            v = t;
        }
        sg.modify(1,1,n,rnk[top[u]],rnk[u],c);
        u = fa[top[u]];
    }
    // Both endpoints share a chain: paint from the shallower to the deeper one.
    const bool u_deeper = dep[u]>dep[v];
    const int upper = u_deeper ? v : u;
    const int lower = u_deeper ? u : v;
    sg.modify(1,1,n,rnk[upper],rnk[lower],c);
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("simple.in","r",stdin);
    freopen("simple.out","w",stdout);
#endif
    int m;
    scanf("%d%d",&n,&m);
    for(int i = 1;i <= n;++i)scanf("%d",&color[i]);
    for(int i = 1,a,b;i < n;++i){
        scanf("%d%d",&a,&b);
        g.push(a,b);
        g.push(b,a);
    }
    dfs1(1,0);
    dfs2(1,1);
    sg.build(1,1,n);
    while(m--){
        char opt;
        scanf(" %c",&opt);
        if(opt=='Q'){
            int l,r;
            scanf("%d%d",&l,&r);
            query(l,r);
        }
        else {
            int l,r,c;
            scanf("%d%d%d",&l,&r,&c);
            modify(l,r,c);
        }
    }
    return 0;
}
posted @ 2020-10-13 17:37  tryatry  阅读(98)  评论(0)  编辑  收藏  举报