BZOJ 2243 [SDOI2011]染色 (树链剖分)(线段树区间修改)
[SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MB
Submit: 6870 Solved: 2546
[Submit][Status][Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
【分析】此题的难点在于处理颜色块的个数。考虑到若两个区间合并,颜色块的个数取决于两边的个数和合并处的两个接口处颜色是否相等,相等则-1.所以对于每个线段树结点(代表区间)维护三个数组--sum[]:此区间的颜色块数;s[]:此区间左端的颜色;t[]:此区间右端的颜色。所以区间合并时,若左儿子的右端点与右儿子的左端点相等,则
sum[rt]=sum[rt*2]+sum[rt*2+1]-1,否则不-1.区间修改采用lazy[]标记。(敲了一下午,调了一晚上,最后发现是个很SB的错误,日。。。)
#include <iostream> #include <cstring> #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> #include <time.h> #include <string> #include <map> #include <stack> #include <vector> #include <set> #include <queue> #define met(a,b) memset(a,b,sizeof a) #define pb push_back #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 using namespace std; typedef long long ll; const int N=2e5+50; const int M=N*N+10; int dep[N],siz[N],fa[N],id[N],son[N],val[N],top[N],c[N]; int num,m,n,q,tot=0; int sum[N*2]; int lazy[N*2],head[N],s[N*2],t[N*2]; struct tree { int to,next; } edg[N*2]; void add(int u,int v) { edg[tot].to=v; edg[tot].next=head[u]; head[u]=tot++; } void dfs1(int u, int f, int d) { dep[u] = d; siz[u] = 1; son[u] = 0; fa[u] = f; for (int i = head[u]; i != -1; i=edg[i].next) { int ff = edg[i].to; if (ff == f) continue; dfs1(ff, u, d + 1); siz[u] += siz[ff]; if (siz[son[u]] < siz[ff]) son[u] = ff; } } void dfs2(int u, int tp) { top[u] = tp; id[u] = ++num; if (son[u]) dfs2(son[u], tp); for (int i = head[u]; i != -1; i=edg[i].next) { int ff = edg[i].to; if (ff == fa[u] || ff == son[u]) continue; dfs2(ff, ff); } } void Push_up(int rt) { if(t[rt*2]==s[rt*2+1])sum[rt]=sum[rt*2]+sum[rt*2+1]-1; else sum[rt]=sum[rt*2]+sum[rt*2+1]; s[rt]=s[rt*2]; t[rt]=t[2*rt+1]; } void Push_down(int rt) { if(lazy[rt]) { s[2*rt]=s[2*rt+1]=t[2*rt]=t[2*rt+1]=lazy[rt]; lazy[2*rt]=lazy[rt]; lazy[2*rt+1]=lazy[rt]; sum[2*rt]=sum[2*rt+1]=1; lazy[rt]=0; } } void Build(int l,int r,int rt) { if(l==r) { lazy[rt]=s[rt]=t[rt]=val[l]; sum[rt]=1; return; } Push_down(rt); int m=(l+r)>>1; Build(lson); Build(rson); Push_up(rt); } void Update(int L,int R,int l,int r,int rt,int add) { if(l>=L&&r<=R) { lazy[rt]=add; s[rt]=t[rt]=add; sum[rt]=1; return; } Push_down(rt); int m=(r+l)>>1; if(L<=m)Update(L,R,lson,add); if(R>m) Update(L,R,rson,add); Push_up(rt); } int Query(int L,int R,int l,int r,int rt) { if(L<=l&&r<=R)return sum[rt]; Push_down(rt); int m=(l+r)>>1,ans=0; if(L<=m)ans+=Query(L,R,lson); if(R>m)ans+=Query(L,R,rson); if(L<=m && R>m && s[rt*2+1]==t[rt*2])ans--; return ans; } void solve(int u,int v,int add) { int tp1 = top[u], tp2 = top[v]; while (tp1 != tp2) { if (dep[tp1] < dep[tp2]) { swap(tp1, tp2); swap(u, v); } Update(id[tp1],id[u],1,n,1,add); u = fa[tp1]; tp1 = top[u]; } if (dep[u] > dep[v]) swap(u, v); Update(id[u],id[v],1,n,1,add); return; } int _find(int u,int l,int r,int rt){ if(l==r)return s[rt]; Push_down(rt); int m=(l+r)/2; if(u<=m)return _find(u,l,m,rt*2); else return _find(u,m+1,r,rt*2+1); } int Answer(int u,int v) { int tp1 = top[u], tp2 = top[v]; int ans=0; while (tp1 != tp2) { if (dep[tp1] < dep[tp2]) { swap(tp1, tp2); swap(u, v); } ans +=Query(id[tp1], id[u],1,n,1); u = fa[tp1]; if(_find(id[tp1],1,n,1)==_find(id[u],1,n,1))ans--; tp1 = top[u]; } if (dep[u] > dep[v]) swap(u, v); ans += Query(id[u], id[v],1,n,1); return ans; } int main() { scanf("%d%d",&n,&m); met(head,-1); int u,v,w; for(int i=1; i<=n; i++) { scanf("%d",&c[i]); } for(int i=1; i<n; i++) { scanf("%d%d",&u,&v); add(u,v); add(v,u); } num = 0; dfs1(1,0,1); dfs2(1,1); for (int i = 1; i <= n; i++) { val[id[i]]=c[i]; } Build(1,num,1); char str[20]; while(m--) { scanf("%s",str); if(str[0]=='C') { scanf("%d%d%d",&u,&v,&w); solve(u,v,w); } else { scanf("%d%d",&u,&v); printf("%d\n",Answer(u,v)); } } return 0; }