BZOJ - 2243 染色 (树链剖分+线段树+区间合并)
线段树维护区间连续段个数即可。设lc为区间左端点颜色,rc为区间右端点颜色,则合并两区间的时候,如果左区间右端点和右区间左端点颜色相同,则连续段个数-1。
在树链上的区间合并可以定义一个结构体作为线段,分成左右两条链暴力合并。也可以考虑到树上的路径中每两个树链“断开”的地方必然有一个结点是另一个结点的祖先,因此如果top[u]的颜色与fa[top[u]]的颜色相同时答案-1即可。
树剖和线段树结合真容易把人搞晕啊,什么时候要用l,r,什么时候要用u,什么时候要用dfn[u],一定要分清楚~~
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=1e5+10,inf=0x3f3f3f3f; 5 int hd[N],ne,n,k,fa[N],son[N],siz[N],dep[N],top[N],dfn[N],rnk[N],tot,a[N],cnt[N<<3],mk[N<<3],lc[N<<3],rc[N<<3]; 6 struct E {int v,nxt;} e[N<<1]; 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;} 8 void dfs1(int u,int f,int d) { 9 fa[u]=f,fa[u]=f,siz[u]=1,dep[u]=d; 10 for(int i=hd[u]; ~i; i=e[i].nxt) { 11 int v=e[i].v; 12 if(v==fa[u])continue; 13 dfs1(v,u,d+1),siz[u]+=siz[v]; 14 if(siz[v]>siz[son[u]])son[u]=v; 15 } 16 } 17 void dfs2(int u,int tp) { 18 top[u]=tp,dfn[u]=++tot,rnk[dfn[u]]=u; 19 if(!son[u])return; 20 dfs2(son[u],top[u]); 21 for(int i=hd[u]; ~i; i=e[i].nxt) { 22 int v=e[i].v; 23 if(v==fa[u]||v==son[u])continue; 24 dfs2(v,v); 25 } 26 } 27 #define ls (u<<1) 28 #define rs (u<<1|1) 29 #define mid ((l+r)>>1) 30 void pu(int u) {lc[u]=lc[ls],rc[u]=rc[rs],cnt[u]=cnt[ls]+cnt[rs]; if(rc[ls]==lc[rs])cnt[u]--;} 31 void pd(int u) {if(mk[u])lc[u]=rc[u]=mk[u],cnt[u]=1,mk[ls]=mk[rs]=mk[u],mk[u]=0;} 32 void build(int u=1,int l=1,int r=tot) { 33 if(l==r) {lc[u]=rc[u]=a[rnk[l]],cnt[u]=1; return;} 34 build(ls,l,mid),build(rs,mid+1,r),pu(u); 35 } 36 void upd(int L,int R,int x,int u=1,int l=1,int r=tot) { 37 pd(u); 38 if(l>=L&&r<=R) {mk[u]=x,pd(u); return;} 39 if(l>R||r<L)return; 40 upd(L,R,x,ls,l,mid),upd(L,R,x,rs,mid+1,r),pu(u); 41 } 42 int getcol(int p,int u=1,int l=1,int r=tot) { 43 pd(u); 44 if(l==r)return lc[u]; 45 return p<=mid?getcol(p,ls,l,mid):getcol(p,rs,mid+1,r); 46 } 47 int qry(int L,int R,int u=1,int l=1,int r=tot) { 48 pd(u); 49 if(l>=L&&r<=R)return cnt[u]; 50 if(l>R||r<L)return 0; 51 int t1=qry(L,R,ls,l,mid),t2=qry(L,R,rs,mid+1,r); 52 int ret=t1+t2; 53 if(t1&&t2&&rc[ls]==lc[rs])ret--; 54 return ret; 55 } 56 void change(int u,int v,int x) { 57 for(; top[u]!=top[v]; u=fa[top[u]]) { 58 if(dep[top[u]]<dep[top[v]])swap(u,v); 59 upd(dfn[top[u]],dfn[u],x); 60 } 61 if(dep[u]<dep[v])swap(u,v); 62 upd(dfn[v],dfn[u],x); 63 } 64 int ask(int u,int v) { 65 int ret=0; 66 for(; top[u]!=top[v]; u=fa[top[u]]) { 67 if(dep[top[u]]<dep[top[v]])swap(u,v); 68 ret+=qry(dfn[top[u]],dfn[u]); 69 if(getcol(dfn[top[u]])==getcol(dfn[fa[top[u]]]))ret--; 70 } 71 if(dep[u]<dep[v])swap(u,v); 72 ret+=qry(dfn[v],dfn[u]); 73 return ret; 74 } 75 int main() { 76 memset(hd,-1,sizeof hd),ne=0; 77 scanf("%d%d",&n,&k); 78 for(int i=1; i<=n; ++i)scanf("%d",&a[i]),a[i]++; 79 for(int i=1; i<n; ++i) { 80 int u,v; 81 scanf("%d%d",&u,&v); 82 addedge(u,v); 83 addedge(v,u); 84 } 85 tot=0,dfs1(1,0,0),dfs2(1,1),build(); 86 while(k--) { 87 char ch; 88 int a,b,c; 89 scanf(" %c%d%d",&ch,&a,&b); 90 if(ch=='Q')printf("%d\n",ask(a,b)); 91 else scanf("%d",&c),c++,change(a,b,c); 92 } 93 return 0; 94 }