[SDOI2011] 染色(Luogu 2486)
题目描述
输入输出格式
输入格式:
输出格式:
对于每个询问操作,输出一行答案。
输入输出样例
输出样例#1:
3 1 2
留一点自己思考的时间.
下面开始讲解.
题目大概题意是给出一颗树,树上每个节点都有一个颜色,让你求出一条路径上的颜色段数量,并需要对树进行修改.
看到树,很容易会想到是数据结构的题目.
需要对树上节点信息进行更新,并做到实时查询,这里采用了树链剖分.
树链剖分可以看我之前的讲解:树链剖分.
那么问题就转化成了记录一个区间的颜色段数量.如果用线段树来记录,那么可以直接存下一个节点的颜色段数量.但是当遇到区间的合并时,就需要判断两个合并的区间相连接的部分是否颜色一样. 如果一样,就将合并后两个区间颜色段数量相加再减一,否则直接相加.
但是当树剖部分中有一个跳链的步骤,从一条链跳到另一条链时就出现了两个区间需要合并的问题.为了解决这个问题,可以写一个函数对线段树中的端点的颜色进行单点查询.然后判断合并.
下面是代码:
#include<bits/stdc++.h> #define mid (left+right>>1) #define ll(x) (x<<1) #define rr(x) (x<<1|1) using namespace std; const int inf=2147483647; const int N=1000000; int n,m,cn[N+10]; int fa[N+10],id[N+10],son[N+10],dep[N+10],cs[N+10],top[N+10],size[N+10],idx=0; int sum[N*4+10],lazy[N*4+10],lc[N*4+10],rc[N*4+10]; int cnt=1,last[N+10]; struct edge{ int to,next; }e[N+10]; int gi(){ int ans=0,f=1;char i=getchar(); while(i<'0'||i>'9'){if(i=='-')f=-1;i=getchar();} while(i>='0'&&i<='9'){ans=ans*10+i-'0';i=getchar();} return ans*f; } void add(int x,int y){ e[++cnt].to=y; e[cnt].next=last[x]; last[x]=cnt; } void dfs1(int x,int deep,int father){ dep[x]=deep; fa[x]=father; int maxson=-1; for(int i=last[x];i;i=e[i].next){ int to=e[i].to; if(to!=father){ dfs1(to,deep+1,x); size[x]+=size[to]; if(maxson<size[to]){ son[x]=to; maxson=size[to]; } } } } void dfs2(int x,int tp){ top[x]=tp; id[x]=++idx; cs[idx]=cn[x]; if(!son[x]) return; dfs2(son[x],tp); for(int i=last[x];i;i=e[i].next){ int to=e[i].to; if(to!=son[x]&&to!=fa[x]) dfs2(to,to); } } void pushdown(int root,int left,int right){ lazy[ll(root)]=lazy[rr(root)]=lazy[root]; lc[ll(root)]=lc[rr(root)]=lazy[root]; rc[ll(root)]=rc[rr(root)]=lazy[root]; sum[ll(root)]=sum[rr(root)]=1; lazy[root]=0; } void pushup(int root,int left,int right){ lc[root]=lc[ll(root)]; rc[root]=rc[rr(root)]; int res=sum[ll(root)]+sum[rr(root)]; if(rc[ll(root)]==lc[rr(root)]) res--; sum[root]=res; } void build(int root,int left,int right){ if(left==right){ lc[root]=rc[root]=cs[left]; //cout<<root<<' '<<cs[left]<<endl;; sum[root]=1; return; } build(ll(root),left,mid); build(rr(root),mid+1,right); pushup(root,left,right); } void updata(int root,int left,int right,int l,int r,int col){ if(l<=left&&right<=r){ sum[root]=1; lazy[root]=col; lc[root]=rc[root]=col; return; } if(lazy[root]) pushdown(root,left,right); if(l<=mid) updata(ll(root),left,mid,l,r,col); if(mid<r) updata(rr(root),mid+1,right,l,r,col); pushup(root,left,right); } void cupdata(int a,int b,int val){ while(top[a]!=top[b]){ if(dep[top[a]]<dep[top[b]]) swap(a,b); updata(1,1,n,id[top[a]],id[a],val); a=fa[top[a]]; } if(id[a]>id[b]) swap(a,b); updata(1,1,n,id[a],id[b],val); } int qcol(int root,int left,int right,int node){ if(left==right) return lc[root]; if(lazy[root]) pushdown(root,left,right); if(node<=mid) return qcol(ll(root),left,mid,node); else return qcol(rr(root),mid+1,right,node); } int query(int root,int left,int right,int l,int r){ if(l<=left&&right<=r) return sum[root]; if(r<left||right<l) return 0; if(lazy[root]) pushdown(root,left,right); if(r<=mid) return query(ll(root),left,mid,l,r); else if(mid<l) return query(rr(root),mid+1,right,l,r); else{ int res=query(ll(root),left,mid,l,r)+query(rr(root),mid+1,right,l,r); if(lc[rr(root)]==rc[ll(root)]) res--; //cout<<res<<endl; return res; } } int cquery(int a,int b){ int res=0; while(top[a]!=top[b]){ if(dep[top[a]]<dep[top[b]]) swap(a,b); res+=query(1,1,n,id[top[a]],id[a]); int LC=qcol(1,1,n,id[top[a]]),RC=qcol(1,1,n,id[fa[top[a]]]); if(LC==RC) res--; a=fa[top[a]]; } if(id[a]>id[b]) swap(a,b); res+=query(1,1,n,id[a],id[b]); printf("%d\n",res); } int main(){ n=gi(); m=gi(); int x , y , val ;char flag; for(int i=1;i<=n;i++) cn[i]=gi(); for(int i=1;i<=n;i++) size[i]=1; for(int i=1;i<n;i++){ x=gi(); y=gi(); add(x,y); add(y,x); } dfs1(1,1,-1); dfs2(1,1); build(1,1,n); //for(int i=1;i<=n;i++) cout<<i<<' '<<cs[i]<<endl; for(int i=1;i<=m;i++){ cin>>flag; if(flag=='C'){ x=gi(); y=gi(); val=gi(); cupdata(x,y,val); } if(flag=='Q'){ x=gi(); y=gi(); cquery(x,y); } } return 0; }