bzoj 2243: [SDOI2011]染色 (树链剖分+线段树 区间合并)
2243: [SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 9854 Solved: 3725
[Submit][Status][Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
思路:
有些区间合并的思想,两个区间不能直接相加,需要比较下他们要相接的两个端点的颜色是否相同,如果相同那么相加的值就要-1,不相同的话直接相加就好了因为是在树上操作,需要用树链剖分处理下,而且更新和查询操作都需要特殊处理下。如果不是在树上操作的话就是一道很简单的线段树了,加了数剖复杂了好多啊。。。
之前lazy标记一直忘了下传,。。找了一天的错。。。
实现代码;
#include<bits/stdc++.h> using namespace std; #define ll long long #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 #define mid ll m = (l+r)>>1 const ll M = 1e5+10; ll cnt,n,q; ll siz[M],son[M],fa[M],top[M],rk[M],tid[M],dep[M],a[M],cnt1,head[M],lazy[M<<2]; struct node{ll to,next;}e[M]; struct node1{ll ls,rs,val;}; node1 sum[M<<2]; void add(ll u,ll v){ e[++cnt1].to = v;e[cnt1].next = head[u];head[u] = cnt1; e[++cnt1].to = u;e[cnt1].next = head[v];head[v] = cnt1; } void dfs1(ll u,ll faz,ll deep){ dep[u] = deep; fa[u] = faz; siz[u] = 1; for(ll i = head[u];i;i=e[i].next){ ll v = e[i].to; if(v != fa[u]){ dfs1(v,u,deep+1); siz[u] += siz[v]; if(son[u] == -1||siz[v] > siz[son[u]]) son[u] = v; } } } void dfs2(ll u,ll t){ top[u] = t; tid[u] = cnt; rk[cnt] = u; cnt++; if(son[u] == -1) return; dfs2(son[u],t); for(ll i = head[u];i;i = e[i].next){ ll v = e[i].to; if(v != son[u]&&v != fa[u]) dfs2(v,v); } } void pushup(ll rt){ sum[rt].ls = sum[rt<<1].ls; sum[rt].rs = sum[rt<<1|1].rs; if(sum[rt<<1].rs==sum[rt<<1|1].ls) sum[rt].val = sum[rt<<1].val+sum[rt<<1|1].val-1; else sum[rt].val = sum[rt<<1].val + sum[rt<<1|1].val; } void build(ll l,ll r,ll rt){ if(l == r){ sum[rt].ls = a[rk[l]]; sum[rt].rs = a[rk[l]]; sum[rt].val = 1; ////cout<<l<<" "<<rk[l]<<" "<<a[rk[l]]<<endl; return ; } mid; build(lson); build(rson); pushup(rt); } void pushdown(ll rt){ if(lazy[rt]){ sum[rt<<1].ls = lazy[rt]; sum[rt<<1|1].rs = lazy[rt]; sum[rt<<1|1].ls = lazy[rt]; sum[rt<<1].rs = lazy[rt]; sum[rt<<1].val = 1; sum[rt<<1|1].val = 1; lazy[rt<<1] = lazy[rt<<1|1] = lazy[rt]; lazy[rt] = 0; } } void update(ll L,ll R,ll c,ll l,ll r,ll rt){ if(L <= l&&R >= r){ sum[rt].val = 1; lazy[rt] = c; sum[rt].ls = c; sum[rt].rs = c; return ; } pushdown(rt); ll m = (l + r) >> 1; if(L <= m) update(L,R,c,lson); if(R > m) update(L,R,c,rson); pushup(rt); } node1 query(ll L,ll R,ll l,ll r,ll rt){ if(L <= l&&R >= r){ return sum[rt]; } pushdown(rt); ll m = (l + r) >> 1; if(L > m) return query(L,R,rson); if(R <= m) return query(L,R,lson); node1 t1 = query(L,m,lson); node1 t2 = query(m+1,R,rson); node1 t; t.ls = t1.ls;t.rs = t2.rs; if(t1.rs==t2.ls) t.val = t1.val+t2.val-1; else t.val = t1.val+t2.val; return t; } void cover(ll x,ll y,ll c){ ll fx = top[x],fy = top[y]; while(fx!=fy){ if(dep[fx] < dep[fy]) swap(fx,fy),swap(x,y); update(tid[fx],tid[x],c,1,n,1); x = fa[fx];fx = top[x]; } if(dep[x] < dep[y]) swap(x,y); update(tid[y],tid[x],c,1,n,1); } ll ask(ll x,ll y){ ll sum = 0; ll lc = -1,rc=-1; ll fx = top[x],fy = top[y]; node1 t; while(fx != fy){ if(dep[fx] < dep[fy]){ swap(fx,fy); swap(x,y); swap(lc,rc); } t = query(tid[fx],tid[x],1,n,1); sum += t.val - (lc==t.rs); x = fa[fx]; fx = top[x]; lc = t.ls; } if(dep[x] < dep[y]) swap(x,y),swap(lc,rc); t = query(tid[y],tid[x],1,n,1); sum += t.val - (lc==t.rs) - (rc==t.ls); //当前是x-y区间与两端的区间相加,所以需要判两个 return sum; } int main() { ll u,v,x,y,m,z; memset(son,-1,sizeof(son)); scanf("%lld%lld",&n,&m); cnt = 1;cnt1 = 1; for(ll i = 1;i <= n;i ++) { scanf("%lld",&x); a[i] = x+1; } for(ll i = 0;i < n-1;i++){ scanf("%lld%lld",&u,&v); add(u,v); } dfs1(1,0,1); dfs2(1,1); build(1,n,1); char op[10]; while(m--){ scanf("%s",op); if(op[0] == 'Q'){ scanf("%lld%lld",&x,&y); printf("%lld\n",ask(x,y)); } else { scanf("%lld%lld%lld",&x,&y,&z); z++; cover(x,y,z); } } return 0; }