bzoj 2243: [SDOI2011]染色 线段树区间合并+树链剖分
2243: [SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 7925 Solved: 2975
[Submit][Status][Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
Source
#pragma comment(linker, "/STACK:1024000000,1024000000") #include<iostream> #include<cstdio> #include<cmath> #include<string> #include<queue> #include<algorithm> #include<stack> #include<cstring> #include<vector> #include<list> #include<set> #include<map> using namespace std; #define LL long long #define pi (4*atan(1.0)) #define eps 1e-8 #define bug(x) cout<<"bug"<<x<<endl; const int N=3e5+10,M=2e6+10,inf=1e9+10; const LL INF=1e18+10,mod=1e9+7; struct edge { int v,next; } edge[N<<1]; int head[N<<1],edg,id,n; /// 树链剖分 int fa[N],dep[N],son[N],siz[N]; // fa父亲,dep深度,son重儿子,siz以该点为子树的节点个数 int a[N],ran[N],top[N],tid[N]; // tid表示边的标号,top通过重边可以到达最上面的点,ran表示标记tid void init() { memset(son,-1,sizeof(son)); memset(head,-1,sizeof(head)); edg=0; id=0; } void add(int u,int v) { edg++; edge[edg].v=v; edge[edg].next=head[u]; head[u]=edg; } void dfs1(int u,int fath,int deep) { fa[u]=fath; siz[u]=1; dep[u]=deep; for(int i=head[u]; i!=-1; i=edge[i].next) { int v=edge[i].v; if(v==fath)continue; dfs1(v,u,deep+1); siz[u]+=siz[v]; if(son[u]==-1||siz[v]>siz[son[u]]) son[u]=v; } } void dfs2(int u,int tp) { tid[u]=++id; top[u]=tp; ran[tid[u]]=u; if(son[u]==-1)return; dfs2(son[u],tp); for(int i=head[u]; i!=-1; i=edge[i].next) { int v=edge[i].v; if(v==fa[u])continue; if(v!=son[u]) dfs2(v,v); } } struct SGT { int la[N<<2],ra[N<<2],ma[N<<2],lazy[N<<2]; void pushup(int pos) { if(ra[pos<<1]==la[pos<<1|1])ma[pos]=ma[pos<<1]+ma[pos<<1|1]-1; else ma[pos]=ma[pos<<1|1]+ma[pos<<1]; la[pos]=la[pos<<1]; ra[pos]=ra[pos<<1|1]; } void pushdown(int pos) { if(lazy[pos]) { la[pos<<1]=la[pos<<1|1]=lazy[pos]; ra[pos<<1]=ra[pos<<1|1]=lazy[pos]; ma[pos<<1]=ma[pos<<1|1]=1; lazy[pos<<1]=lazy[pos<<1|1]=lazy[pos]; lazy[pos]=0; } } pair<int,pair<int,int> > Union( pair<int,pair<int,int> > a, pair<int,pair<int,int> > b) { if(a.second.second==b.second.first) return make_pair(a.first+b.first-1,make_pair(a.second.first,b.second.second)); return make_pair(a.first+b.first,make_pair(a.second.first,b.second.second)); } void build(int l,int r,int pos) { lazy[pos]=0; if(l==r) { la[pos]=ra[pos]=a[ran[l]]; ma[pos]=1; return; } int mid=(l+r)>>1; build(l,mid,pos<<1); build(mid+1,r,pos<<1|1); pushup(pos); } void update(int L,int R,int c,int l,int r,int pos) { if(L<=l&&r<=R) { lazy[pos]=c; la[pos]=ra[pos]=c; ma[pos]=1; return; } pushdown(pos); int mid=(l+r)>>1; if(L<=mid)update(L,R,c,l,mid,pos<<1); if(R>mid)update(L,R,c,mid+1,r,pos<<1|1); pushup(pos); } pair<int,pair<int,int> > query(int L,int R,int l,int r,int pos) { if(L<=l&&r<=R) return make_pair(ma[pos],make_pair(la[pos],ra[pos])); pushdown(pos); int mid=(l+r)>>1; if(L>mid)return query(L,R,mid+1,r,pos<<1|1); else if(R<=mid)return query(L,R,l,mid,pos<<1); else { pair<int,pair<int,int> > a=query(L,mid,l,mid,pos<<1); pair<int,pair<int,int> > b=query(mid+1,R,mid+1,r,pos<<1|1); return Union(a,b); } } }tree; int lca(int l,int r) { while(top[l]!=top[r]) { if(dep[top[l]]<dep[top[r]])swap(l,r); l=fa[top[l]]; } if(dep[l]<dep[r])swap(l,r); return r; } int up(int l,int r) { int pre=-1,ans=0; while(top[l]!=top[r]) { if(dep[top[l]]<dep[top[r]])swap(l,r); pair<int,pair<int,int> > x=tree.query(tid[top[l]],tid[l],1,n,1); //cout<<tid[top[l]]<<" "<<tid[l]<<" "<<x.first<<endl; ans+=x.first; if(pre==x.second.second)ans--; pre=x.second.first; l=fa[top[l]]; } if(dep[l]<dep[r])swap(l,r); pair<int,pair<int,int> > x=tree.query(tid[r],tid[l],1,n,1); //cout<<tid[r]<<" "<<tid[l]<<" "<<x.first<<endl; ans+=x.first; if(pre==x.second.second)ans--; return ans; } void go(int l,int r,int c) { while(top[l]!=top[r]) { if(dep[top[l]]<dep[top[r]])swap(l,r); tree.update(tid[top[l]],tid[l],c,1,n,1); l=fa[top[l]]; } if(dep[l]<dep[r])swap(l,r); tree.update(tid[r],tid[l],c,1,n,1); } char ch[10]; int main() { init(); int q; scanf("%d%d",&n,&q); for(int i=1;i<=n;i++) scanf("%d",&a[i]); for(int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); add(u,v); add(v,u); } dfs1(1,-1,1); dfs2(1,1); tree.build(1,n,1); while(q--) { int u,v; scanf("%s%d%d",ch,&u,&v); if(ch[0]=='C') { int c; scanf("%d",&c); go(u,v,c); } else { int x=lca(u,v); printf("%d\n",up(u,x)+up(v,x)-1); } } return 0; }