[SDOI2011] 染色(Luogu 2486)

题目描述

输入输出格式

输入格式:

 

 

输出格式:

 

对于每个询问操作,输出一行答案。

 

输入输出样例

输入样例#1: 
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
输出样例#1:
3
1
2

 

 

留一点自己思考的时间.

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

下面开始讲解.

题目大概题意是给出一颗树,树上每个节点都有一个颜色,让你求出一条路径上的颜色段数量,并需要对树进行修改.

看到树,很容易会想到是数据结构的题目.

需要对树上节点信息进行更新,并做到实时查询,这里采用了树链剖分.

树链剖分可以看我之前的讲解:树链剖分.

那么问题就转化成了记录一个区间的颜色段数量.如果用线段树来记录,那么可以直接存下一个节点的颜色段数量.但是当遇到区间的合并时,就需要判断两个合并的区间相连接的部分是否颜色一样. 如果一样,就将合并后两个区间颜色段数量相加再减一,否则直接相加.

但是当树剖部分中有一个跳链的步骤,从一条链跳到另一条链时就出现了两个区间需要合并的问题.为了解决这个问题,可以写一个函数对线段树中的端点的颜色进行单点查询.然后判断合并.

下面是代码:

#include<bits/stdc++.h>
#define mid (left+right>>1)
#define ll(x) (x<<1)
#define rr(x) (x<<1|1)
using namespace std;
const int inf=2147483647;
const int N=1000000;

int n,m,cn[N+10];
int fa[N+10],id[N+10],son[N+10],dep[N+10],cs[N+10],top[N+10],size[N+10],idx=0;
int sum[N*4+10],lazy[N*4+10],lc[N*4+10],rc[N*4+10];
int cnt=1,last[N+10];

struct edge{
  int to,next;
}e[N+10];

int gi(){
  int ans=0,f=1;char i=getchar();
  while(i<'0'||i>'9'){if(i=='-')f=-1;i=getchar();}
  while(i>='0'&&i<='9'){ans=ans*10+i-'0';i=getchar();}
  return ans*f;
}

void add(int x,int y){
  e[++cnt].to=y;
  e[cnt].next=last[x];
  last[x]=cnt;
}

void dfs1(int x,int deep,int father){
  dep[x]=deep; fa[x]=father; int maxson=-1;
  for(int i=last[x];i;i=e[i].next){
    int to=e[i].to;
    if(to!=father){
      dfs1(to,deep+1,x);
      size[x]+=size[to];
      if(maxson<size[to]){
    son[x]=to;
    maxson=size[to];
      }
    }
  }
}

void dfs2(int x,int tp){
  top[x]=tp; id[x]=++idx; cs[idx]=cn[x];
  if(!son[x]) return;
  dfs2(son[x],tp);
  for(int i=last[x];i;i=e[i].next){
    int to=e[i].to;
    if(to!=son[x]&&to!=fa[x])
      dfs2(to,to);
  }
}

void pushdown(int root,int left,int right){
  lazy[ll(root)]=lazy[rr(root)]=lazy[root];
  lc[ll(root)]=lc[rr(root)]=lazy[root];
  rc[ll(root)]=rc[rr(root)]=lazy[root];
  sum[ll(root)]=sum[rr(root)]=1;
  lazy[root]=0;
}

void pushup(int root,int left,int right){
  lc[root]=lc[ll(root)]; rc[root]=rc[rr(root)];
  int res=sum[ll(root)]+sum[rr(root)];
  if(rc[ll(root)]==lc[rr(root)]) res--;
  sum[root]=res;
}

void build(int root,int left,int right){
  if(left==right){
    lc[root]=rc[root]=cs[left];
    //cout<<root<<' '<<cs[left]<<endl;;
    sum[root]=1;
    return;
  }
  build(ll(root),left,mid);
  build(rr(root),mid+1,right);
  pushup(root,left,right);
}

void updata(int root,int left,int right,int l,int r,int col){
  if(l<=left&&right<=r){
    sum[root]=1; lazy[root]=col;
    lc[root]=rc[root]=col;
    return;
  }
  if(lazy[root]) pushdown(root,left,right);
  if(l<=mid) updata(ll(root),left,mid,l,r,col);
  if(mid<r) updata(rr(root),mid+1,right,l,r,col);
  pushup(root,left,right);
}

void cupdata(int a,int b,int val){
  while(top[a]!=top[b]){
    if(dep[top[a]]<dep[top[b]]) swap(a,b);
    updata(1,1,n,id[top[a]],id[a],val);
    a=fa[top[a]];
  }
  if(id[a]>id[b]) swap(a,b);
  updata(1,1,n,id[a],id[b],val);
}
 
int qcol(int root,int left,int right,int node){
  if(left==right) return lc[root];
  if(lazy[root]) pushdown(root,left,right);
  if(node<=mid) return qcol(ll(root),left,mid,node);
  else return qcol(rr(root),mid+1,right,node);
}

int query(int root,int left,int right,int l,int r){
  if(l<=left&&right<=r) return sum[root];
  if(r<left||right<l) return 0;
  if(lazy[root]) pushdown(root,left,right);
  if(r<=mid) return query(ll(root),left,mid,l,r);
  else if(mid<l) return query(rr(root),mid+1,right,l,r);
  else{
    int res=query(ll(root),left,mid,l,r)+query(rr(root),mid+1,right,l,r);
    if(lc[rr(root)]==rc[ll(root)]) res--;
    //cout<<res<<endl;
    return res;
  }
}

int cquery(int a,int b){
  int res=0;
  while(top[a]!=top[b]){
    if(dep[top[a]]<dep[top[b]]) swap(a,b);
    res+=query(1,1,n,id[top[a]],id[a]);
    int LC=qcol(1,1,n,id[top[a]]),RC=qcol(1,1,n,id[fa[top[a]]]);
    if(LC==RC) res--;
    a=fa[top[a]];
  }
  if(id[a]>id[b]) swap(a,b);
  res+=query(1,1,n,id[a],id[b]);
  printf("%d\n",res);
}

int main(){
  n=gi(); m=gi(); int x , y , val ;char flag;
  for(int i=1;i<=n;i++) cn[i]=gi();
  for(int i=1;i<=n;i++) size[i]=1;
  for(int i=1;i<n;i++){
    x=gi(); y=gi();
    add(x,y); add(y,x);
  }
  dfs1(1,1,-1); dfs2(1,1); build(1,1,n);
  //for(int i=1;i<=n;i++) cout<<i<<' '<<cs[i]<<endl;
  for(int i=1;i<=m;i++){
    cin>>flag;
    if(flag=='C'){
      x=gi(); y=gi(); val=gi();
      cupdata(x,y,val);
    }
    if(flag=='Q'){
      x=gi(); y=gi();
      cquery(x,y);
    }
  }
  return 0;
}

 

posted @ 2018-02-06 21:09  Brave_Cattle  阅读(176)  评论(0编辑  收藏  举报