// bzoj2243 [SDOI2011] 染色 (Coloring) — 动态树 (dynamic tree)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

// Capacity of every per-node array; node ids are 1..n and id 0 is the null node.
#define N 110000

// Splay forest. ch[x][0] / ch[x][1] are the children of x inside its splay;
// acess() always hangs the deeper part of a path on ch[x][1], so the in-order
// sequence of a splay runs from the shallowest node to the deepest one.
// pre[x] is the splay parent of x; when rt[x] is set, pre[x] is instead the
// node the whole path hangs from (initially the BFS parent, 0 for vertex 1).
int pre[N],ch[N][2];
// Adjacency list of the input tree: e[x] = first edge id of x, ne[i] = next
// edge id in the same list, v[i] = endpoint of edge i. nn = edges stored.
int e[N],ne[N*2],v[N*2];
int nn,m;
// col[x]: colour of node x itself.
int col[N];
// Aggregates over the splay subtree of x, read as an in-order sequence:
//   lc[x]  colour of the first node,  rc[x]  colour of the last node,
//   num[x] number of maximal runs of equal colour,
//   sm[x]  lazy flag: the subtree was painted col[x] and the children have
//          not been told yet.
int lc[N],sm[N],rc[N],num[N];
// rt[x] != 0  <=>  x is the root of its splay tree.
int rt[N],n;
// BFS queue used once in main(): he = index of the last element, bo = index
// of the next element to pop.
int qu[N],he,bo;
// Scratch stack for dfs(); a splay never holds more than n < N nodes.
int stk[N];

// Turns every vertex into a one-node splay and sets up the null node 0 so
// that it contributes no segment (num[0] = 0).
void init(){
    // -1 is kept as the "no colour" marker of the null node. up() and ask()
    // test for a null child explicitly, so they do not depend on it.
    lc[0]=rc[0]=-1;
    num[0]=0;
    for(int i=1;i<=n;i++){
        rt[i]=1;
        num[i]=1;
        lc[i]=rc[i]=col[i];
    }
}

// Recomputes lc/rc/num of x from its children and its own colour.
// Callers make sure sm[x] is clear first (see down()).
void up(int x){
    int l=ch[x][0],r=ch[x][1];
    num[x]=num[l]+num[r]+1;
    // A run that continues across the boundary with x must not be counted twice.
    if(l&&rc[l]==col[x])num[x]--;
    if(r&&lc[r]==col[x])num[x]--;
    lc[x]=l?lc[l]:col[x];
    rc[x]=r?rc[r]:col[x];
}

// Paints the whole splay subtree of x with colour y: the aggregates of x are
// set directly and the lazy flag is raised for the children. No-op for x == 0.
void sam(int x,int y){
    if(!x)return;
    num[x]=sm[x]=1;
    col[x]=lc[x]=rc[x]=y;
}

// Pushes a pending "paint" flag from x to both of its children.
void down(int x){
    if(!sm[x])return;
    sm[x]=0;
    sam(ch[x][0],col[x]);
    sam(ch[x][1],col[x]);
}

// Rotates x above its splay parent y. Only y is re-aggregated here; x is
// re-aggregated by splay() once it has reached the top.
void rotate(int x){
    int y=pre[x],z=pre[y];
    down(y);
    down(x);
    int k=(ch[y][0]==x);        // 1 when x is the left child of y
    int mid=ch[x][k];           // subtree that changes sides
    ch[y][!k]=mid;
    pre[mid]=y;                 // may write pre[0]; pre[0] is never read
    ch[x][k]=y;
    pre[y]=x;
    pre[x]=z;
    if(!rt[y]){
        ch[z][ch[z][1]==y]=x;   // z is a real splay parent: relink it
    }else{
        rt[y]=0;                // y was the splay root: x inherits that role
        rt[x]=1;                // and pre[x] = z keeps the path-parent link
    }
    up(y);
}

// Pushes all pending paint flags on the chain from the splay root down to x,
// topmost node first. Done with an explicit stack: the chain can be as long as
// the represented path, so recursing over it could exhaust the call stack.
void dfs(int x){
    int top=0;
    stk[top++]=x;
    while(!rt[x]){
        x=pre[x];
        stk[top++]=x;
    }
    while(top)down(stk[--top]);
}

// Appends the directed edge x -> y to the adjacency list of x.
void add(int x,int y){
    ne[++nn]=e[x];
    e[x]=nn;
    v[nn]=y;
}

// Brings x to the root of its splay tree and refreshes its aggregates.
void splay(int x){
    dfs(x);                     // flags above x must be resolved before rotating
    while(!rt[x]){
        int y=pre[x],z=pre[y];
        if(rt[y]){
            rotate(x);          // zig
        }else if((ch[z][1]==y)==(ch[y][1]==x)){
            rotate(y);          // zig-zig: parent first
            rotate(x);
        }else{
            rotate(x);          // zig-zag
            rotate(x);
        }
    }
    up(x);
}

// Makes the path from the top of x's tree down to x a single splay (nothing
// deeper than x stays attached to it). Returns the last node that was splayed;
// C() and ask() use that node as the meeting point of the two paths after
// acess(a); acess(b).
int acess(int x){
    int y=0;
    for(;x;y=x,x=pre[x]){
        splay(x);
        rt[ch[x][1]]=1;         // old deeper part becomes its own splay
        ch[x][1]=y;             // attach the part we climbed up from
        rt[y]=0;
        up(x);
    }
    return y;
}

// Paints every vertex on the tree path a..b with colour c.
void C(int a,int b,int c){
    acess(a);
    int y=acess(b);
    // After the two calls y is a splay root whose right subtree holds the
    // part of the path strictly below y towards b (empty when y == b).
    if(a!=y){
        // The part strictly below y towards a was detached by acess(b) and
        // forms exactly the splay that contains a.
        splay(a);
        sam(a,c);
    }
    sam(ch[y][1],c);
    col[y]=c;
    up(y);
}

// Returns the number of maximal equal-colour runs on the tree path a..b.
int ask(int a,int b){
    acess(a);
    int y=acess(b);
    if(a!=y)splay(a);           // num[a] / lc[a] now describe the a-side branch
    int r=ch[y][1];
    int ans=num[r]+1;           // b-side branch plus y itself
    if(r&&lc[r]==col[y])ans--;  // y continues the first run of the b side
    if(a!=y){
        ans+=num[a];
        if(lc[a]==col[y])ans--; // y continues the first run of the a side
    }
    return ans;
}

// True when x is a usable vertex id for the current input.
bool valid(int x){
    return x>=1&&x<=n;
}

// Input: n m, then n colours, then n-1 edges, then m operations of the form
// "C a b c" (paint path a..b with c) or any other letter followed by "a b"
// (print the number of colour runs on path a..b).
// Processing stops quietly at the first malformed or out-of-range item.
int main(){
    if(scanf("%d%d",&n,&m)!=2)return 0;
    if(n<1||n>=N)return 0;      // arrays hold ids 1..N-1 only
    for(int i=1;i<=n;i++){
        if(scanf("%d",&col[i])!=1)return 0;
    }
    for(int i=1;i<n;i++){
        int a,b;
        if(scanf("%d%d",&a,&b)!=2)return 0;
        if(!valid(a)||!valid(b))return 0;
        add(a,b);
        add(b,a);
    }

    // BFS from vertex 1 to orient the tree: pre[y] becomes the parent of y.
    he=bo=1;
    qu[1]=1;
    while(bo<=he){
        int x=qu[bo++];
        for(int i=e[x];i;i=ne[i]){
            int y=v[i];
            if(y==1||pre[y])continue;   // already reached
            pre[y]=x;
            qu[++he]=y;
        }
    }
    init();

    for(int i=1;i<=m;i++){
        char in;
        int a,b,c;
        if(scanf(" %c",&in)!=1)break;
        if(in=='C'){
            if(scanf("%d%d%d",&a,&b,&c)!=3)break;
            if(!valid(a)||!valid(b))break;
            C(a,b,c);
        }else{
            if(scanf("%d%d",&a,&b)!=2)break;
            if(!valid(a)||!valid(b))break;
            printf("%d\n",ask(a,b));
        }
    }
    return 0;
}