bzoj2243 [SDOI2011]染色 (树链剖分+线段树)
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
思路:先用两次dfs进行树链剖分,把重链和轻链都找出来,计算出每一个节点对应的线段树的位置,然后在线段树上维护4个值,分别是cl表示线段最左端的颜色,cr表示线段最右端的颜色,num表示线段中不同颜色的区间数,flag是lazy标记,表示这段区间是不是只有一种颜色。所以修改操作就只要普通的修改就行了,这题重点是询问操作,我们可以用两个值pre1,pre2表示之前修改的链的右端点的颜色,初始化为-1,pre1表示u到lca上的链的修改,pre2表示v到lca上的链的修改,那么我们每次询问一条链的颜色的时候,求出这条线段的左端点和右端点的颜色以及这条线段中不同颜色的区间个数,如果左端点的颜色和pre1相同,那么总的颜色区间个数还要减去1.
#include<iostream>
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<math.h>
#include<vector>
#include<map>
#include<set>
#include<string>
#include<bitset>
#include<algorithm>
using namespace std;
#define lson th<<1
#define rson th<<1|1
typedef long long ll;
typedef long double ldb;
#define inf 999999999
#define pi acos(-1.0)
#define maxn 100010
struct edge{
int to,next;
}e[2*maxn];
int pos,tot,Lc,Rc;
int top[maxn],son[maxn],fa[maxn],dep[maxn],num[maxn],p[maxn],a[maxn],c[maxn],first[maxn];
void dfs1(int u,int pre,int deep)
{
int i,j,v;
dep[u]=deep;
fa[u]=pre;
num[u]=1;
for(i=first[u];i!=-1;i=e[i].next){
v=e[i].to;
if(v==pre)continue;
dfs1(v,u,deep+1);
num[u]+=num[v];
if(son[u]==-1 || num[son[u] ]<num[v])son[u]=v;
}
}
void dfs2(int u,int tp)
{
int i,j,v;
top[u]=tp;
if(son[u]!=-1){
p[u]=++pos;c[pos]=a[u];dfs2(son[u],tp);
}
else{
p[u]=++pos;c[pos]=a[u];return;
}
for(i=first[u];i!=-1;i=e[i].next){
v=e[i].to;
if(v==son[u] || v==fa[u])continue;
dfs2(v,v);
}
}
struct node{
int l,r,num,cl,cr,flag;
}b[4*maxn];
void pushup(int th)
{
b[th].cl=b[lson].cl;
b[th].cr=b[rson].cr;
b[th].num=b[lson].num+b[rson].num;
if(b[lson].cr==b[rson].cl)b[th].num--;
}
void pushdown(int th)
{
if(b[th].flag==1){
b[lson].flag=b[rson].flag=1;
b[lson].cl=b[lson].cr=b[rson].cl=b[rson].cr=b[th].cl;
b[lson].num=b[rson].num=1;
b[th].flag=-1;
}
}
void build(int l,int r,int th)
{
int mid;
b[th].l=l;b[th].r=r;
b[th].flag=-1;
if(l==r){b[th].cl=b[th].cr=c[l];b[th].num=1;return;}
mid=(b[th].l+b[th].r)/2;
build(l,mid,lson);build(mid+1,r,rson);
pushup(th);
}
void update(int l,int r,int color,int th)
{
int mid;
if(b[th].l==l && b[th].r==r){
b[th].num=1;b[th].flag=1;b[th].cl=b[th].cr=color;return;
}
pushdown(th);
mid=(b[th].l+b[th].r)/2;
if(r<=mid)update(l,r,color,lson);
else if(l>mid)update(l,r,color,rson);
else{update(l,mid,color,lson);update(mid+1,r,color,rson);}
pushup(th);
}
int question(int l,int r,int th,int L,int R)
{
int mid,num;
if(b[th].l==L)Lc=b[th].cl;
if(b[th].r==R)Rc=b[th].cr;
if(b[th].l==l && b[th].r==r){
return b[th].num;
}
pushdown(th);
mid=(b[th].l+b[th].r)/2;
if(r<=mid)return question(l,r,lson,L,R);
else if(l>mid)return question(l,r,rson,L,R);
else{
num=question(l,mid,lson,L,R)+question(mid+1,r,rson,L,R);
if(b[lson].cr==b[rson].cl)num--;
return num;
}
}
int solve(int u,int v)
{
int f1=top[u],f2=top[v];
int num=0,pre1,pre2;
pre1=pre2=-1;
while(f1!=f2){
if(dep[f1]<dep[f2]){swap(pre1,pre2);swap(f1,f2);swap(u,v);}
num+=question(p[f1],p[u],1,p[f1],p[u]);
if(pre1==Rc)num--;
pre1=Lc;u=fa[f1];f1=top[u];
}
if(dep[u]<dep[v]){swap(pre1,pre2);swap(u,v);}
num+=question(p[v],p[u],1,p[v],p[u]);
if(Rc==pre1)num--;
if(Lc==pre2)num--;
return num;
}
void gengxin(int u,int v,int value)
{
int f1=top[u],f2=top[v];
while(f1!=f2){
if(dep[f1]<dep[f2]){swap(f1,f2);swap(u,v);}
update(p[f1],p[u],value,1);
u=fa[f1];
f1=top[u];
}
if(dep[u]<dep[v])swap(u,v);
update(p[v],p[u],value,1);
}
void add(int u,int v)
{
tot++;
e[tot].next=first[u];e[tot].to=v;
first[u]=tot;
}
int main()
{
int i,j,n,m,u,v,f,g,h;
char s[10];
while(scanf("%d%d",&n,&m)!=EOF)
{
memset(first,-1,sizeof(first));
memset(son,-1,sizeof(son));
pos=0;tot=0;
for(i=1;i<=n;i++){
scanf("%d",&a[i]);
}
for(i=1;i<=n-1;i++){
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
dfs1(1,0,1);dfs2(1,1);build(1,pos,1);
for(i=1;i<=m;i++){
scanf("%s",s);
if(s[0]=='Q'){
scanf("%d%d",&f,&g);
printf("%d\n",solve(f,g));
}
else{
scanf("%d%d%d",&f,&g,&h);
gengxin(f,g,h);
}
}
}
return 0;
}