4999: This Problem Is Too Simple!
Description
给您一颗树,每个节点有个初始值。
现在支持以下两种操作:
- C i x(0<=x<2^31) 表示将i节点的值改为x。
- Q i j x(0<=x<2^31) 表示询问i节点到j节点的路径上有多少个值为x的节点。
解题报告:
用时:1h20min,3WA
简单题,对每一种颜色建一棵树链剖分的数组,可以持久化一下,动态加点,暴力搞搞即可
空间时间复杂度:\(O((n+m)logn)\)
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#define RG register
#define il inline
#define iter iterator
#define Max(a,b) ((a)>(b)?(a):(b))
#define Min(a,b) ((a)<(b)?(a):(b))
using namespace std;
const int N=100005;
int n,m,head[N],num=0,to[N<<1],nxt[N<<1],col[N],id[N],DFN=0,b[N<<2],sum=0;
struct question{int flag,x,y,z;}q[N<<1];
void link(int x,int y){nxt[++num]=head[x];to[num]=y;head[x]=num;}
int dep[N],top[N],fa[N],sz[N],son[N],tot;char s[3];
void dfs1(int x){
int u;sz[x]=1;
for(int i=head[x];i;i=nxt[i]){
u=to[i];if(dep[u])continue;
dep[u]=dep[x]+1;fa[u]=x;
dfs1(u);
sz[x]+=sz[u];
if(sz[u]>sz[son[x]])son[x]=u;
}
}
void dfs2(int x,int tp){
top[x]=tp;id[x]=++DFN;
if(son[x])dfs2(son[x],tp);
for(int i=head[x];i;i=nxt[i])
if(to[i]!=son[x] && to[i]!=fa[x])dfs2(to[i],to[i]);
}
int totnode=0,root[N<<2];
struct node{int l,r,s;}t[N*160];
void updata(int &rt,int last,int l,int r,int sa,int to){
rt=++totnode;t[rt]=t[last];
if(l==r){t[rt].s+=to;return ;}
int mid=(l+r)>>1;
if(sa>mid)updata(t[rt].r,t[last].r,mid+1,r,sa,to);
else updata(t[rt].l,t[last].l,l,mid,sa,to);
t[rt].s=t[t[rt].l].s+t[t[rt].r].s;
}
int query(int rt,int l,int r,int sa,int se){
if(l>se || r<sa)return 0;
if(sa<=l && r<=se)return t[rt].s;
int mid=(l+r)>>1;
return query(t[rt].l,l,mid,sa,se)+query(t[rt].r,mid+1,r,sa,se);
}
int solve(int x,int y,int co){
int ret=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
ret+=query(root[co],1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(id[x]>id[y])swap(x,y);
ret+=query(root[co],1,n,id[x],id[y]);
return ret;
}
void work()
{
int x,y;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",&col[i]),b[++sum]=col[i];
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
link(x,y);link(y,x);
}
dep[1]=1;dfs1(1);dfs2(1,1);
for(int i=1;i<=m;i++){
scanf("%s%d%d",s,&q[i].x,&q[i].y);
if(s[0]=='C')q[i].flag=0,b[++sum]=q[i].y;
else scanf("%d",&q[i].z),q[i].flag=1,b[++sum]=q[i].z;
}
sort(b+1,b+sum+1);
tot=unique(b+1,b+sum+1)-b-1;
for(int i=1;i<=n;i++){
col[i]=lower_bound(b+1,b+tot+1,col[i])-b;
updata(root[col[i]],root[col[i]],1,n,id[i],1);
}
for(int i=1;i<=m;i++){
x=q[i].x;y=q[i].y;
if(q[i].flag==0){
y=lower_bound(b+1,b+tot+1,y)-b;
updata(root[col[x]],root[col[x]],1,n,id[x],-1);
updata(root[y],root[y],1,n,id[x],1);
col[x]=y;
}
else{
q[i].z=lower_bound(b+1,b+tot+1,q[i].z)-b;
printf("%d\n",solve(x,y,q[i].z));
}
}
}
int main()
{
work();
return 0;
}