bzoj2243: [SDOI2011]染色
2243: [SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 3271 Solved: 1262
[Submit][Status][Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
Source
思路:这题维护比较复杂,LCT和树链剖分都可以写,树链剖分还是好写一些。
如果写树链剖分的话,线段树里记录左右端点颜色,颜色个数和覆盖标记,update时判断对接处的两点颜色是否相同,相同颜色数就等于左半部分+右半部分-1。树链剖分统计答案时也一样,记录左右端点颜色,相同就减1.
如果写LCT,splay记录的东西和树链剖分差不多,但update时要合并两次,上半段和自己及自己和下半段。标记也有两种,一种是把一个点设为根是的翻转,一种是覆盖。
记得开long long...
树链剖分代码:
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long #define ls p<<1 #define rs ((p<<1)|1) using namespace std; const ll maxm=200010,maxn=100010,maxt=maxn<<3; ll n,m,pre[maxm],now[maxn],son[maxm],col[maxn],tot,root; ll fa[maxn],hson[maxn],w[maxn],top[maxn],dep[maxn],size[maxn],T,q,a[maxn]; char s[10]; void add(ll a,ll b){pre[++tot]=now[a],now[a]=tot,son[tot]=b;} struct node{ll lcol,rcol,val,cov;}; struct Tree{ node t[maxt]; void update(ll p){ t[p].lcol=t[ls].lcol; t[p].rcol=t[rs].rcol; t[p].val=t[ls].val+t[rs].val-(t[ls].rcol==t[rs].lcol); } void down(ll p){ if (t[p].cov){ t[ls].cov=t[rs].cov=t[p].cov; t[ls].lcol=t[ls].rcol=t[rs].lcol=t[rs].rcol=t[p].cov; t[ls].val=t[rs].val=1; t[p].cov=0; } } void build(ll p,ll l,ll r){ if (l==r){ t[p].lcol=t[p].rcol=a[l],t[p].val=1; return; } ll mid=(l+r)>>1; build(ls,l,mid),build(rs,mid+1,r); update(p); } void change(ll p,ll l,ll r,ll a,ll b,ll v){ if (l==a&&r==b){ t[p].val=1,t[p].lcol=t[p].rcol=t[p].cov=v; // printf("%d %d %d %d %d\n",l,r,a,b,v); return; } ll mid=(l+r)>>1; down(p); if (b<=mid) change(ls,l,mid,a,b,v); else if (a>mid) change(rs,mid+1,r,a,b,v); else change(ls,l,mid,a,mid,v),change(rs,mid+1,r,mid+1,b,v); update(p); } node query(ll p,ll l,ll r,ll a,ll b){ if (a==l&&b==r) return t[p];//printf("%d %d %d %d %d\n",l,r,a,b,t[p].val), ll mid=(l+r)>>1; down(p);//printf("%d %d %d %d %d\n",p,l,r,a,b); if (b<=mid) return query(ls,l,mid,a,b); else if (a>mid) return query(rs,mid+1,r,a,b); else{ node tmp,t1=query(ls,l,mid,a,mid),t2=query(rs,mid+1,r,mid+1,b); tmp.lcol=t1.lcol,tmp.rcol=t2.rcol; tmp.val=t1.val+t2.val-(t1.rcol==t2.lcol); return tmp; } } }Seg; void dfs(ll x){ size[x]=1,hson[x]=0; for (ll y=now[x];y;y=pre[y]) if (son[y]!=fa[x]){ fa[son[y]]=x,dep[son[y]]=dep[x]+1,dfs(son[y]); if (size[son[y]]>size[hson[x]]) hson[x]=son[y]; size[x]+=size[son[y]]; } } void btree(ll x,ll tp){ w[x]=++m,a[m]=col[x],top[x]=tp; if (hson[x]) btree(hson[x],top[x]); for (ll y=now[x];y;y=pre[y]) if (son[y]!=fa[x]&&son[y]!=hson[x]) btree(son[y],son[y]); } void cover(ll a,ll b,ll c){ ll f1=top[a],f2=top[b]; while (f1!=f2){ if (dep[f1]<dep[f2]) swap(f1,f2),swap(a,b); Seg.change(1,1,m,w[f1],w[a],c); a=fa[f1],f1=top[a]; } if (dep[a]>dep[b]) swap(a,b); // printf("%d %d %d\n",w[a],w[b],c); Seg.change(1,1,m,w[a],w[b],c); } ll answer(ll a,ll b){ ll f1=top[a],f2=top[b],acol=-1,bcol=-1,ans=0;//分别表a端和b端接口的颜色 node tmp; while (f1!=f2){ if (dep[f1]<dep[f2]) swap(f1,f2),swap(acol,bcol),swap(a,b); // printf("%d %d\n",a,f1); tmp=Seg.query(1,1,m,w[f1],w[a]); // printf("%d\n",tmp.val); ans+=tmp.val-(acol==tmp.rcol);//因为a在f1下,所以应该是a与这一段链的右颜色对接 a=fa[f1],f1=top[a],acol=tmp.lcol; } if (dep[a]<dep[b]) swap(a,b),swap(acol,bcol); tmp=Seg.query(1,1,m,w[b],w[a]);//printf("tmp%d\n",tmp.val); ans+=tmp.val-(acol==tmp.rcol)-(bcol==tmp.lcol); return ans; } int main(){ scanf("%d%d",&n,&q);root=1; for (ll i=1;i<=n;i++) scanf("%d",&col[i]),col[i]++; for (ll i=1,a,b;i<=n-1;i++) scanf("%d%d",&a,&b),add(a,b),add(b,a); dfs(root),btree(root,root),Seg.build(1,1,m); // for (ll i=1;i<=n;i++) printf("%d %d %d %d\n",i,hson[i],top[i],col[i]); for (ll i=1,a,b,c;i<=q;i++){ scanf("%s",s); if (s[0]=='C') scanf("%d%d%d",&a,&b,&c),c++,cover(a,b,c); else scanf("%d%d",&a,&b),printf("%d\n",answer(a,b)); } return 0; }