// View code
#include <bits/stdc++.h>
using namespace std;
// Inclusive ascending loop helper: _for(i,a,b) runs i = a, a+1, ..., b.
#define _for(i,a,b) for(int i = (a);i <= (b);++i)
typedef long long ll;
// Capacity of the per-node arrays (node ids are used 1-based, up to n).
const int maxn = 1e5+5;
// Modulus used by qpow.
const int mod = 1e9+7;
// Fast modular exponentiation: returns a^b mod `mod`, always in [0, mod).
// `a` is first reduced into [0, mod) so every product below is < mod^2,
// which fits in a signed 64-bit integer; without this a large `a` overflows
// in a*a and a negative `a` produces negative results.
// Precondition: b >= 0 (a negative b never reaches 0 under b >>= 1 on
// arithmetic-shift implementations, so the loop would not terminate).
ll qpow(ll a,ll b){
    a %= mod;
    if(a < 0)a += mod;
    ll res = 1;
    for(;b;b>>=1){
        if(b&1)res = res*a%mod;
        a = a*a%mod;
    }
    return res;
}
// Adjacency list in "forward star" form: edges are stored in parallel arrays
// and chained per source vertex through nxt[]. head[v] == -1 means v has no
// outgoing edges; edge ids are handed out from 0 upward via sz.
struct graph
{
    int head[maxn],nxt[maxn<<1],to[maxn<<1],w[maxn<<1],sz;
    // Reset to an empty graph. sz must be cleared too: otherwise re-init keeps
    // stale edge ids, and a non-static graph would start from a garbage sz.
    void init(){memset(head,-1,sizeof(head));sz = 0;}
    graph(){init();}
    // Append the directed edge a -> b with weight c (default 0).
    void push(int a,int b,int c=0){nxt[sz]=head[a],to[sz]=b,w[sz]=c,head[a]=sz++;}
    // Target vertex of edge id a, so g[i] is the neighbour reached via edge i.
    int& operator[](const int a){return to[a];}
}g;
int size[maxn],son[maxn],top[maxn],dfn[maxn],rnk[maxn],dep[maxn],fa[maxn],tot;
void dfs1(int now,int pre)
{
dep[now] = dep[pre]+1;
fa[now] = pre;
size[now] = 1;
for(int i = g.head[now];~i;i = g.nxt[i]){
if(g[i]==pre)continue;
dfs1(g[i],now);
size[now] += size[g[i]];
if(size[g[i]]>size[son[now]])son[now] = g[i];
}
}
void dfs2(int now,int tp)
{
top[now] = tp;
dfn[++tot] = now;
rnk[now] = tot;
if(son[now])dfs2(son[now],tp);
for(int i = g.head[now];~i;i = g.nxt[i]){
if(g[i]==fa[now]||g[i]==son[now])continue;
dfs2(g[i],g[i]);
}
}
// Summary of a contiguous stretch of colours:
//   lc  - colour at the left (lower position) end
//   rc  - colour at the right (higher position) end
//   sum - number of maximal runs of equal colour; 0 marks an empty stretch
struct node
{
    int lc,rc,sum;
    // A default-constructed node is the empty stretch (sum 0) instead of
    // holding indeterminate values.
    node():lc(0),rc(0),sum(0){}
    node(int _l,int _r,int _sum){
        lc = _l,rc = _r,sum = _sum;
    }
};
int color[maxn];
struct Segment_tree
{
node tree[maxn<<2];
int lazy[maxn<<2];
void build(int root,int l,int r){
if(l==r){
tree[root].lc = color[dfn[l]];
tree[root].rc = color[dfn[r]];
tree[root].sum = 1;
return;
}
int mid = l+r>>1;
build(root<<1,l,mid);
build(root<<1|1,mid+1,r);
tree[root] = pushup(tree[root<<1],tree[root<<1|1]);
}
void pushdown(int root){
if(lazy[root]){
tree[root<<1].lc = lazy[root];
tree[root<<1].rc = lazy[root];
tree[root<<1].sum = 1;
tree[root<<1|1].lc = lazy[root];
tree[root<<1|1].rc = lazy[root];
tree[root<<1|1].sum = 1;
lazy[root<<1] = lazy[root<<1|1] = lazy[root];
lazy[root] = 0;
}
}
node pushup(node a,node b){
if(a.sum==0)return b;
if(b.sum==0)return a;
node tmp;
tmp.lc = a.lc;
tmp.rc = b.rc;
tmp.sum = a.sum+b.sum;
if(a.rc==b.lc)tmp.sum--;
return tmp;
}
void modify(int root,int l,int r,int ml,int mr,int col){
if(l >= ml&&r <= mr){
tree[root].lc = col;
tree[root].rc = col;
lazy[root] = col;
tree[root].sum = 1;
return;
}
pushdown(root);
int mid = l+r>>1;
if(mid>=ml)modify(root<<1,l,mid,ml,mr,col);
if(mr>mid)modify(root<<1|1,mid+1,r,ml,mr,col);
tree[root] = pushup(tree[root<<1],tree[root<<1|1]);
}
node query(int root,int l,int r,int ql,int qr){
if(l >= ql&&r <= qr)return tree[root];
pushdown(root);
int mid = l+r>>1;
node tmp;
if(mid<ql){
tmp = query(root<<1|1,mid+1,r,ql,qr);
}
else if(qr<=mid)tmp = query(root<<1,l,mid,ql,qr);
else tmp = pushup(query(root<<1,l,mid,ql,qr),query(root<<1|1,mid+1,r,ql,qr));
return tmp;
}
}sg;
int n;
void query(int l,int r)
{
node tmp1,tmp2;
tmp1.sum = 0;
tmp2.sum = 0;
int x = l,y = r;
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]){
node a = sg.query(1,1,n,rnk[top[x]],rnk[x]);
tmp1 = sg.pushup(a,tmp1);
x = fa[top[x]];
}
else{
node b = sg.query(1,1,n,rnk[top[y]],rnk[y]);
tmp2 = sg.pushup(b,tmp2);
y = fa[top[y]];
}
}
if(dep[x]>dep[y]){
node a = sg.query(1,1,n,rnk[y],rnk[x]);
tmp1 = sg.pushup(a,tmp1);
}
else {
node b = sg.query(1,1,n,rnk[x],rnk[y]);
tmp2 = sg.pushup(b,tmp2);
}
int ans = tmp1.sum+tmp2.sum;
if(tmp1.sum==0||tmp2.sum==0);
else if(tmp1.lc==tmp2.lc)ans--;
printf("%d\n",ans);
}
// Recolour every node on the tree path between l and r with colour c.
void modify(int l,int r,int c)
{
    int u = l,v = r;
    // Climb chain by chain: always work on the endpoint whose chain top is
    // deeper, painting from its chain top down to it, then jump above the top.
    for(;top[u]!=top[v];u = fa[top[u]]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);
        sg.modify(1,1,n,rnk[top[u]],rnk[u],c);
    }
    // Same chain now: paint the contiguous position range between them.
    const int shallower = dep[u]>dep[v] ? v : u;
    const int deeper = dep[u]>dep[v] ? u : v;
    sg.modify(1,1,n,rnk[shallower],rnk[deeper],c);
}
// Program entry point.
// Reads n and m, the colour of each of the n nodes, and n-1 undirected edges.
// It then builds the decomposition rooted at node 1 and the segment tree, and
// processes m operations:
//   'Q l r'           - print the number of colour segments on path l..r
//   any other opcode  - read 'l r c' and recolour path l..r with c
// Outside the online judge, stdin/stdout are redirected to simple.in/simple.out.
int main()
{
#ifndef ONLINE_JUDGE
    freopen("simple.in","r",stdin);
    freopen("simple.out","w",stdout);
#endif
    int m;
    scanf("%d%d",&n,&m);
    for(int v = 1;v <= n;++v)
        scanf("%d",&color[v]);
    for(int e = 1;e < n;++e){
        int u,v;
        scanf("%d%d",&u,&v);
        // Store both directions so the DFS can walk the tree from the root.
        g.push(u,v);
        g.push(v,u);
    }
    dfs1(1,0);
    dfs2(1,1);
    sg.build(1,1,n);
    while(m--){
        char op;
        scanf(" %c",&op);   // leading space skips whitespace before the opcode
        if(op!='Q'){
            int u,v,c;
            scanf("%d%d%d",&u,&v,&c);
            modify(u,v,c);
            continue;
        }
        int u,v;
        scanf("%d%d",&u,&v);
        query(u,v);
    }
    return 0;
}