[bzoj4999] This Problem Is Too Simple!
题意:给你一棵n个结点的树,树上的每个点都有一个点权,有m次操作,1 将结点i的权值改为x.2 询问路径(i,j)上有多少个点的权值等于x
题解:
树链剖分+动态加点线段树(主席树?)
直接树链剖分
给每个权值开一棵线段树,维护这个权值在区间\([l,r]\)出现了多少次
这里和普通线段树不同的是,线段树中每个结点的编号是动态开的,这样有利于节省时空复杂度 = =?
注意:线段树的数组要开到100倍......
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<map>
#define ll long long
#define N 100010
using namespace std;
int n,qu,e_num,tot,id,z;
int nxt[N<<1],to[N<<1],h[N];
int fa[N],dep[N],top[N],siz[N],son[N],dfn[N],a[N];
int rt[N<<2],tr[N<<7],ls[N<<7],rs[N<<7];
char s[5];
map<int,int> mp;
int gi() {
int x=0,o=1; char ch=getchar();
while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar();
if(ch=='-') o=-1,ch=getchar();
while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
return o*x;
}
void add(int x, int y) {
nxt[++e_num]=h[x],to[e_num]=y,h[x]=e_num;
}
void dfs1(int u) {
siz[u]=1;
for(int i=h[u]; i; i=nxt[i]) {
int v=to[i];
if(v==fa[u]) continue;
fa[v]=u,dep[v]=dep[u]+1;
dfs1(v);
if(siz[v]>siz[son[u]]) son[u]=v;
siz[u]+=siz[v];
}
}
void dfs2(int u) {
dfn[u]=++z;
if(son[u]) top[son[u]]=top[u],dfs2(son[u]);
for(int i=h[u]; i; i=nxt[i]) {
int v=to[i];
if(v==fa[u] || v==son[u]) continue;
top[v]=v,dfs2(v);
}
}
void update(int &x, int l, int r, int qx, int val) {
if(!x) x=++id;
if(l==r) {tr[x]+=val;return;}
int mid=(l+r)>>1;
if(qx<=mid) update(ls[x],l,mid,qx,val);
else update(rs[x],mid+1,r,qx,val);
tr[x]=tr[ls[x]]+tr[rs[x]];
}
int query(int x, int l, int r, int ql, int qr) {
if(ql<=l && r<=qr) return tr[x];
int mid=(l+r)>>1;
if(qr<=mid) return query(ls[x],l,mid,ql,qr);
else if(ql>mid) return query(rs[x],mid+1,r,ql,qr);
else return (query(ls[x],l,mid,ql,mid)+query(rs[x],mid+1,r,mid+1,qr));
}
int solve(int x, int y, int k) {
int ret=0;
while(top[x]!=top[y]) {
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ret+=query(rt[k],1,n,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
if(dep[x]<dep[y]) swap(x,y);
ret+=query(rt[k],1,n,dfn[y],dfn[x]);
return ret;
}
int main() {
n=gi(),qu=gi();
for(int i=1; i<=n; i++) a[i]=gi();
for(int i=1; i<n; i++) {
int x=gi(),y=gi();
add(x,y),add(y,x);
}
fa[1]=1,dep[1]=1,top[1]=1;
dfs1(1),dfs2(1);
for(int i=1; i<=n; i++) {
if(!mp[a[i]]) mp[a[i]]=++tot;
update(rt[mp[a[i]]],1,n,dfn[i],1);
}
for(int i=1; i<=qu; i++) {
scanf("%s", s);
if(s[0]=='C') {
int x=gi(),y=gi();
update(rt[mp[a[x]]],1,n,dfn[x],-1);
if(!mp[y]) mp[y]=++tot;
update(rt[mp[y]],1,n,dfn[x],1);
a[x]=y;//错误点
}
else {
int l=gi(),r=gi(),x=gi();
if(!mp[x]) puts("0");
else printf("%d\n", solve(l,r,mp[x]));
}
}
return 0;
}