CF620E New Year Tree (DFS order + segment tree / dfs序+线段树)
[Problem statement on Codeforces](https://codeforces.com/contest/620/problem/E)
// Problem: CF620E New Year Tree
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/CF620E
// Memory Limit: 250 MB
// Time Limit: 3000 ms
//
// Powered by CP Editor (https://cpeditor.org)
//
// Approach: number the vertices in DFS pre-order so that the subtree of u
// occupies the contiguous index range [in[u], out[u]]. A segment tree over
// that order stores, per node, the bitwise OR of the colour bitmasks of its
// range (bit c set <=> colour c occurs). Query type 1 assigns one colour to a
// whole subtree (range assignment with a lazy tag); query type 2 ORs the
// subtree range and prints the number of set bits, i.e. distinct colours.
//
// NOTE: colours are used as shift amounts (1ll << c), so every colour must lie
// in [0, 62]; the CF620E statement bounds colours by 60.
#include <bitset>
#include <cstdio>
#include <cstring>

typedef long long ll;

const int maxn = 400000 + 100;  // vertex-count bound (n <= 4e5) plus slack

int n, m;
ll c[maxn];                                   // initial colour of each vertex
int e[maxn * 2], ne[maxn * 2], h[maxn], idx;  // linked adjacency lists; h[u] == -1 means no edges
int in[maxn], out[maxn], tot, pos[maxn];      // DFS entry/exit index; pos[index] = vertex
int stk_u[maxn], stk_e[maxn], stk_fa[maxn];   // explicit stack for the iterative DFS

struct node {
    int l, r;
    ll laz;  // pending "assign this mask to the whole segment"; 0 means none
    ll sum;  // OR of the colour masks over [l, r]
} tr[maxn * 4];

// Reads one (optionally negative) integer from stdin. Returns 0 if EOF is hit
// before any digit, instead of looping forever.
inline ll read() {
    ll x = 0, f = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable from data
    while ((ch < '0' || ch > '9') && ch != EOF) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}

// Appends the directed edge u -> v to u's adjacency list.
void add(int u, int v) {
    e[++idx] = v, ne[idx] = h[u], h[u] = idx;
}

// Iterative pre-order DFS from `root` filling in[], out[] and pos[].
// It visits children in the same adjacency-list order a recursive DFS would,
// but uses an explicit stack so a deep (path-like) tree cannot overflow the
// call stack.
void dfs(int root) {
    int top = 0;
    stk_u[0] = root;
    stk_fa[0] = 0;  // vertices are 1-based, so 0 is "no parent"
    stk_e[0] = h[root];
    in[root] = ++tot;
    pos[tot] = root;
    while (top >= 0) {
        int u = stk_u[top];
        int &edge = stk_e[top];  // next unexplored edge of u
        if (edge == -1) {        // all children done: close u's subtree range
            out[u] = tot;
            --top;
            continue;
        }
        int v = e[edge];
        edge = ne[edge];  // advance before descending so we resume correctly
        if (v == stk_fa[top]) continue;
        ++top;
        stk_u[top] = v;
        stk_fa[top] = u;
        stk_e[top] = h[v];
        in[v] = ++tot;
        pos[tot] = v;
    }
}

// Recomputes node u's mask from its two children.
void pushup(int u) {
    tr[u].sum = tr[u << 1].sum | tr[u << 1 | 1].sum;
}

// Paints the whole segment of node u with `mask` and records it as pending.
void assign_mask(int u, ll mask) {
    tr[u].sum = mask;
    tr[u].laz = mask;
}

// Forwards a pending assignment of node u to its children.
void pushdown(int u) {
    if (tr[u].laz != 0) {
        assign_mask(u << 1, tr[u].laz);
        assign_mask(u << 1 | 1, tr[u].laz);
        tr[u].laz = 0;
    }
}

// Builds node u over DFS positions [l, r] from the initial colours.
void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r;
    tr[u].laz = 0;
    if (l == r) {
        tr[u].sum = 1ll << c[pos[l]];
        return;
    }
    int mid = (l + r) / 2;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

// Sets every position in [l, r] to colour v.
void update(int u, int l, int r, ll v) {
    if (l <= tr[u].l && tr[u].r <= r) {
        assign_mask(u, 1ll << v);
        return;
    }
    pushdown(u);
    int mid = (tr[u].l + tr[u].r) / 2;
    if (l <= mid) update(u << 1, l, r, v);
    if (r > mid) update(u << 1 | 1, l, r, v);
    pushup(u);
}

// Returns the OR of colour masks over [l, r]. No pushup is needed afterwards:
// a tagged node's sum already equals its tag, which equals the children's OR.
ll qask(int u, int l, int r) {
    if (l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
    pushdown(u);
    int mid = (tr[u].l + tr[u].r) / 2;
    ll res = 0;
    if (l <= mid) res |= qask(u << 1, l, r);
    if (r > mid) res |= qask(u << 1 | 1, l, r);
    return res;
}

int main() {
    memset(h, -1, sizeof h);
    n = read(), m = read();
    for (int i = 1; i <= n; i++) c[i] = read();
    for (int i = 1; i < n; i++) {
        int u = read(), v = read();
        add(u, v), add(v, u);
    }
    dfs(1);
    build(1, 1, n);
    while (m--) {
        int op = read();
        int u = read();
        if (op == 1) {
            ll v = read();
            update(1, in[u], out[u], v);
        } else {
            ll mask = qask(1, in[u], out[u]);
            // Each set bit is one distinct colour present in the subtree.
            printf("%d\n", (int)std::bitset<64>((unsigned long long)mask).count());
        }
    }
    return 0;
}