BZOJ 2243 [SDOI2011]染色 树链剖分+线段树
题意
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
分析
用线段树维护下区间左右端点的颜色,区间不同颜色段的个数,区间合并时接口的颜色相同要减一,
树剖一下,跑一遍就行了
Code
#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define lson l,mid,p<<1
#define rson mid+1,r,p<<1|1
using namespace std;
typedef long long ll;
const int inf=1e9;
const int maxn=3e5+10;
int n,q;
int a[maxn];
vector<int>g[maxn];
int sz[maxn],f[maxn],d[maxn],top[maxn],son[maxn],id[maxn],p[maxn],tot;
int val[maxn<<2],c1[maxn<<2],c2[maxn<<2],tag[maxn<<2];
struct ppo{
int val,c1,c2;
};
void pp(int p){
val[p]=val[p<<1]+val[p<<1|1]-(c2[p<<1]==c1[p<<1|1]);
c1[p]=c1[p<<1];c2[p]=c2[p<<1|1];
}
void bd(int l,int r,int p){
tag[p]=-1;
if(l==r){
val[p]=1;c1[p]=c2[p]=a[id[l]];return;
}int mid=l+r>>1;
bd(lson);bd(rson);pp(p);
}
void pd(int l,int r,int p,int k){val[p]=1;c1[p]=c2[p]=k;tag[p]=k;};
void up(int dl,int dr,int l,int r,int p,int k){
if(l>=dl&&r<=dr){
val[p]=1;c1[p]=c2[p]=k;
tag[p]=k;return;
}int mid=l+r>>1;
if(~tag[p]){pd(lson,tag[p]);pd(rson,tag[p]);tag[p]=-1;}
if(dl<=mid) up(dl,dr,lson,k);
if(dr>mid) up(dl,dr,rson,k);
pp(p);
}
ppo mer(ppo a,ppo b){
a.val=a.val+b.val-(a.c2==b.c1);
a.c2=b.c2;
return a;
}
ppo qy(int dl,int dr,int l,int r,int p){
if(l>=dl&&r<=dr){
return ppo{val[p],c1[p],c2[p]};
}int mid=l+r>>1;
if(~tag[p]){pd(lson,tag[p]);pd(rson,tag[p]);tag[p]=-1;}
if(dr<=mid){
return qy(dl,dr,lson);
}else if(dl>mid){
return qy(dl,dr,rson);
}else{
return mer(qy(dl,dr,lson),qy(dl,dr,rson));
}
}
void dfs1(int u){
sz[u]=1;d[u]=d[f[u]]+1;
for(int i=0;i<g[u].size();i++){
int x=g[u][i];
if(x==f[u]) continue;
f[x]=u;dfs1(x);
sz[u]+=sz[x];
if(sz[x]>sz[son[u]]) son[u]=x;
}
}
void dfs2(int u,int t){
top[u]=t;p[u]=++tot;id[tot]=u;
if(son[u]) dfs2(son[u],t);
for(int i=0;i<g[u].size();i++){
int x=g[u][i];
if(x==f[u]||x==son[u]) continue;
dfs2(x,x);
}
}
void col(int x,int y,int k){
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
up(p[top[x]],p[x],1,n,1,k);
x=f[top[x]];
}
if(d[x]<d[y]) swap(x,y);
up(p[y],p[x],1,n,1,k);
}
int cal(int x,int y){
ppo a=ppo{0,-1,-1},b=ppo{0,-1,-1};
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]){
ppo res=qy(p[top[y]],p[y],1,n,1);
if(b.c1==-1) b=res;
else b=mer(res,b);
y=f[top[y]];
}else{
ppo res=qy(p[top[x]],p[x],1,n,1);
if(a.c1==-1) a=res;
else a=mer(res,a);
x=f[top[x]];
}
}
if(d[x]<d[y]){
ppo res=qy(p[x],p[y],1,n,1);
b=mer(res,b);
}else{
ppo res=qy(p[y],p[x],1,n,1);
a=mer(res,a);
}
int ret=a.val+b.val-(a.c1==b.c1);
return ret;
}
int main(){
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
}
for(int i=1,a,b;i<n;i++){
scanf("%d%d",&a,&b);
g[a].pb(b);g[b].pb(a);
}
dfs1(1);dfs2(1,1);bd(1,n,1);
while(q--){
int a,b,c;
char s[3];
scanf("%s",s);
if(s[0]=='C'){
scanf("%d%d%d",&a,&b,&c);
col(a,b,c);
}else{
scanf("%d%d",&a,&b);
printf("%d\n",cal(a,b));
}
}
return 0;
}