P2486 [SDOI2011]染色 题解
基本上所有树剖题都可以用LCT维护。两种思路,一种是直接暴力地维护每个点表示的区间的左右端点颜色和颜色段数,另一种是把连接同种颜色的边边权设为 \(0\),不同种颜色设为 \(1\),然后维护路径和即可。
一个注意的点在于LCT中的 \(tag\) 是对所有跟左右有关的值进行取反的标记。一定不要只交换了左右儿子,忘了改左右颜色。
点击查看代码
#include<iostream>
#include<cstdio>
using namespace std;
const int N=1e5+13;
struct Edge{int v,nxt;}e[N<<1];
struct data{
int sum,lc,rc;
data operator +(const data &a)const{
data ans;
ans.lc=lc,ans.rc=a.rc;
ans.sum=sum+a.sum-(rc==a.lc);
return ans;
}
};
struct SegTree{int l,r;data x;bool set;}t[N<<2];
int n,m,tot,cnt,h[N],a[N],b[N];
int fa[N],siz[N],dep[N],son[N],top[N],id[N];
inline void add(int u,int v){e[++tot]=(Edge){v,h[u]};h[u]=tot;}
void dfs1(int u,int f,int deep){
fa[u]=f,siz[u]=1,dep[u]=deep;
int maxson=0;
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==f) continue;
dfs1(v,u,deep+1);
siz[u]+=siz[v];
if(siz[v]>maxson) maxson=siz[v],son[u]=v;
}
}
void dfs2(int u,int topf){
top[u]=topf,id[u]=++cnt,b[cnt]=a[u];
if(!son[u]) return;
dfs2(son[u],topf);
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].v;
if(v!=fa[u]&&v!=son[u]) dfs2(v,v);
}
}
#define ls p<<1
#define rs p<<1|1
#define mid ((t[p].l+t[p].r)>>1)
inline void refresh(int p){t[p].x=t[ls].x+t[rs].x;}
void build(int p,int l,int r){
t[p].l=l,t[p].r=r;
if(l==r){t[p].x=(data){1,b[l],b[l]};return;}
build(ls,l,mid);build(rs,mid+1,r);
refresh(p);
}
inline void pushup(int p,int k){
t[p].x=(data){1,k,k};
t[p].set=1;
}
inline void pushdown(int p){
if(!t[p].set) return;
pushup(ls,t[p].x.lc);
pushup(rs,t[p].x.lc);
t[p].set=0;
}
void update(int p,int l,int r,int k){
if(l<=t[p].l&&t[p].r<=r) return pushup(p,k);
pushdown(p);
if(l<=mid) update(ls,l,r,k);
if(r>mid) update(rs,l,r,k);
refresh(p);
}
data query(int p,int l,int r){
if(l<=t[p].l&&t[p].r<=r) return t[p].x;
pushdown(p);
if(r<=mid) return query(ls,l,r);
if(l>mid) return query(rs,l,r);
return query(ls,l,r)+query(rs,l,r);
}
inline void t_update(int u,int v,int w){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
update(1,id[top[u]],id[u],w);
u=fa[top[u]];
}
if(id[u]>id[v]) swap(u,v);
update(1,id[u],id[v],w);
}
inline int t_query(int u,int v){
data res1,res2;bool flag1=0,flag2=0;int res;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]){
if(!flag2) res2=query(1,id[top[v]],id[v]),flag2=1;
else res2=query(1,id[top[v]],id[v])+res2;
v=fa[top[v]];
}
else{
if(!flag1) res1=query(1,id[top[u]],id[u]),flag1=1;
else res1=query(1,id[top[u]],id[u])+res1;
u=fa[top[u]];
}
}
if(id[u]>id[v]){
if(!flag1) res1=query(1,id[v],id[u]);
else res1=query(1,id[v],id[u])+res1;
if(!flag2) res=res1.sum;
else res=res1.sum+res2.sum-(res1.lc==res2.lc);
}
else{
if(!flag2) res2=query(1,id[u],id[v]);
else res2=query(1,id[u],id[v])+res2;
if(!flag1) res=res2.sum;
else res=res1.sum+res2.sum-(res1.lc==res2.lc);
}
return res;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i) scanf("%d",&a[i]);
for(int i=1,u,v;i<n;++i){
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs1(1,0,0),dfs2(1,1),build(1,1,n);
for(int i=1,x,y,z;i<=m;++i){
char c;cin>>c;
if(c=='C'){
scanf("%d%d%d",&x,&y,&z);
t_update(x,y,z);
}
else{
scanf("%d%d",&x,&y);
printf("%d\n",t_query(x,y));
}
}
return 0;
}