[SDOI2011][BZOJ2243] 染色|线段树|树链剖分
2243: [SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 3583 Solved: 1362
[Submit][Status][Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
Source
好久没打树链剖分了……又写(抄)了一遍。
树链剖分+线段树区间操作就可过。
#include<iostream> #include<cstdio> #include<cstring> #include<cstdlib> #include<algorithm> #include<cmath> #include<vector> #include<set> #include<map> #include<queue> #define N 100005 using namespace std; int head[N],next[2*N],list[2*N]; int l[4*N],r[4*N],s[4*N],tag[4*N],size[N],lc[4*N],rc[4*N],c[N],fa[N][20],deep[N],id[N],belong[N]; int n,m,cnt,dfn; inline int read() { int a=0,f=1; char c=getchar(); while (c<'0'||c>'9') {if (c=='-') f=-1; c=getchar();} while (c>='0'&&c<='9') {a=a*10+c-'0'; c=getchar();} return a*f; } inline void insert(int x,int y) { next[++cnt]=head[x]; head[x]=cnt; list[cnt]=y; } void dfs1(int x) { size[x]=1; for (int i=1;(1<<i)<=deep[x];i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for (int i=head[x];i;i=next[i]) { if (list[i]==fa[x][0]) continue; fa[list[i]][0]=x; deep[list[i]]=deep[x]+1; dfs1(list[i]); size[x]+=size[list[i]]; } } void dfs2(int x,int chain) { id[x]=++dfn; belong[x]=chain; int k=0; for (int i=head[x];i;i=next[i]) if (list[i]!=fa[x][0]&&size[list[i]]>size[k]) k=list[i]; if (!k) return; dfs2(k,chain); for (int i=head[x];i;i=next[i]) if (list[i]!=fa[x][0]&&list[i]!=k) dfs2(list[i],list[i]); } inline int lca(int x,int y) { if (deep[x]<deep[y]) swap(x,y); int t=deep[x]-deep[y]; for (int i=0;(1<<i)<=t;i++) if ((1<<i)&t) x=fa[x][i]; for (int i=18;i>=0;i--) if (fa[x][i]!=fa[y][i]) {x=fa[x][i]; y=fa[y][i];} return x==y?x:fa[x][0]; } void build_tree(int k,int ll,int rr) { l[k]=ll; r[k]=rr; s[k]=1; tag[k]=-1; if (ll==rr) return; int mid=(ll+rr)>>1; build_tree(k<<1,ll,mid); build_tree(k<<1|1,mid+1,rr); } void pushup(int k) { lc[k]=lc[k<<1]; rc[k]=rc[k<<1|1]; if (rc[k<<1]^lc[k<<1|1]) s[k]=s[k<<1]+s[k<<1|1]; else s[k]=s[k<<1]+s[k<<1|1]-1; } void pushdown(int k) { int tmp=tag[k]; tag[k]=-1; if (tmp==-1||l[k]==r[k]) return; s[k<<1]=s[k<<1|1]=1; tag[k<<1]=tag[k<<1|1]=tmp; lc[k<<1]=rc[k<<1]=tmp; lc[k<<1|1]=rc[k<<1|1]=tmp; } void change(int k,int x,int y,int v) { pushdown(k); if (l[k]==x&&r[k]==y) { lc[k]=rc[k]=v; s[k]=1; tag[k]=v; return; } int mid=(l[k]+r[k])>>1; if (mid>=y) change(k<<1,x,y,v); else if (mid<x) change(k<<1|1,x,y,v); else { change(k<<1,x,mid,v); change(k<<1|1,mid+1,y,v); } pushup(k); } int ask(int k,int x,int y) { pushdown(k); if (l[k]==x&&r[k]==y) return s[k]; int mid=(l[k]+r[k])>>1; if (mid>=y) return ask(k<<1,x,y); else if (mid<x) return ask(k<<1|1,x,y); else { int tmp=1; if (rc[k<<1]^lc[k<<1|1]) tmp=0; return ask(k<<1,x,mid)+ask(k<<1|1,mid+1,y)-tmp; } } int getc(int k,int x) { pushdown(k); if (l[k]==r[k])return lc[k]; int mid=(l[k]+r[k])>>1; if (x<=mid) return getc(k<<1,x); else return (getc(k<<1|1,x)); } inline void solvechange(int x,int f,int c) { while (belong[x]!=belong[f]) { change(1,id[belong[x]],id[x],c); x=fa[belong[x]][0]; } change(1,id[f],id[x],c); } inline int solvesum(int x,int f) { int sum=0; while (belong[x]!=belong[f]) { sum+=ask(1,id[belong[x]],id[x]); if (getc(1,id[belong[x]])==getc(1,id[fa[belong[x]][0]])) sum--; x=fa[belong[x]][0]; } sum+=ask(1,id[f],id[x]); return sum; } int main() { n=read(); m=read(); for (int i=1;i<=n;i++) c[i]=read(); for (int i=1;i<n;i++) { int u=read(),v=read(); insert(u,v); insert(v,u); } dfs1(1); dfs2(1,1); build_tree(1,1,n); for (int i=1;i<=n;i++) change(1,id[i],id[i],c[i]); for (int i=1;i<=m;i++) { char ch[1]; scanf("%s",ch); int a=read(),b=read(),v; if (ch[0]=='C') { v=read(); int t=lca(a,b); solvechange(a,t,v); solvechange(b,t,v); } else { int t=lca(a,b); printf("%d\n",solvesum(a,t)+solvesum(b,t)-1); } } return 0; }