Codeforces 620E New Year Tree(线段树+位运算)

题目链接 New Year Tree

考虑到颜色 $c_k \le 60$,可以把每种颜色对应到一个 64 位整数的一个二进制位上:区间内出现过的颜色用按位或合并,最后统计结果中 1 的个数,就是颜色种数。

对于每个点,按 DFS 序重新标号,并算出它对应的进入和离开的时间戳;这样每棵子树都对应 DFS 序上的一段连续区间,问题就转化为区间更新 + 区间查询。

用线段树来维护。

  1 #include <bits/stdc++.h>
  2 
  3 using namespace std;
  4 
  // rep(i, a, b): loop header that declares an int counter i and runs it
  // from a up to b inclusive. The bound b is re-evaluated on every iteration.
  #define rep(i, a, b) for (int i = (a); i <= (b); ++i)
  6 
  // One segment-tree cell.
  //   num  - bitmask of colours: pushup() ORs the children's masks together,
  //          build()/update() store a single bit (1LL << colour).
  //   lazy - colour that pushdown() still has to hand to both children;
  //          0 means nothing is pending.
  struct node{
      long long num;
      long long lazy;
  };

  // Segment-tree storage; the children of cell i live at 2*i and 2*i + 1.
  node tree[400010 << 2];
 10 
 // DFS timestamp range of one vertex: dfs() stores the entry timestamp in l
 // and the value of Time at the moment it leaves the vertex in r.
 struct Node{
     int l;
     int r;
 };

 // Indexed by vertex number.
 Node e[400010];
 14 
 // Adjacency lists, indexed by vertex number; main() adds every edge in both directions.
 std::vector<int> v[400010];

 int n;                  // first integer read by main(): number of vertices
 int m;                  // second integer read by main(): number of queries
 long long val[400010];  // colour stored at each DFS timestamp (filled by dfs())
 long long c[400010];    // colour of each vertex as read from the input
 int Time;               // running DFS timestamp counter
 bool vis[400010];       // set to true by dfs() for every vertex it enters
 long long ans;          // colour mask accumulated by solve()
 long long cover;        // colour argument of the current type-1 query
 int op;                 // type of the current query
 int x;                  // scratch: edge endpoint / query vertex
 int y;                  // scratch: second edge endpoint
 24 
 25 void dfs(int x, int fa){
 26     e[x].l = ++Time;
 27     val[Time] = c[x];
 28     vis[x] = true;
 29     for (auto u : v[x]){
 30         if (u == fa) continue;
 31         dfs(u, x);
 32     }
 33 
 34     e[x].r = Time;
 35 }
 36 
 37 inline void pushup(int i){
 38     tree[i].num = tree[i << 1].num | tree[i << 1 | 1].num;
 39 }
 40 
 41 inline void pushdown(int i){
 42     if (tree[i].lazy){
 43         tree[i << 1].num = tree[i << 1 | 1].num = (1LL << tree[i].lazy);
 44         tree[i << 1].lazy = tree[i << 1 | 1].lazy = tree[i].lazy;
 45         tree[i].lazy = 0;
 46     }
 47 }
 48 
 49 void build(int i, int l, int r){
 50     tree[i].lazy = 0;
 51     if (l == r){
 52         tree[i].num = (1LL << val[l]);
 53         return ;
 54     }
 55 
 56     int mid = (l + r) >> 1;
 57     build(i << 1, l, mid);
 58     build(i << 1 | 1, mid + 1, r);
 59     pushup(i);
 60 }
 61 
 62 void update(int i, int L, int R, int l, int r, long long cover){
 63     if (l <= L && R <= r){
 64         tree[i].lazy = cover;
 65         tree[i].num = (1LL << cover);
 66         return ;
 67     }
 68 
 69     int mid = (L + R) >> 1;
 70     pushdown(i);
 71     if (l <= mid) update(i << 1, L, mid, l, r, cover);
 72     if (r > mid) update(i << 1 | 1, mid + 1, R, l, r, cover);
 73     pushup(i);
 74 }
 75 
 76 void solve(int i, int L, int R, int l, int r){
 77     if (l <= L && R <= r){
 78         ans |= tree[i].num;
 79         return;
 80     }
 81 
 82     pushdown(i);
 83     int mid = (L + R) >> 1;
 84     if (l <= mid) solve(i << 1, L, mid, l, r);
 85     if (r > mid) solve(i << 1 | 1, mid + 1, R, l, r);
 86 }
 87 
 88 int main(){
 89 
 90     scanf("%d%d", &n, &m);
 91 
 92     rep(i, 1, n) v[i].clear();
 93     rep(i, 1, n) scanf("%lld", c + i);
 94     rep(i, 1, n - 1){
 95         scanf("%d%d", &x, &y);
 96         v[x].push_back(y);
 97         v[y].push_back(x);
 98     }
 99 
100     memset(vis, 0, sizeof vis); Time = 0;
101     dfs(1, 0);
102     build(1, 1, n);
103 
104     rep(i, 1, m){
105         scanf("%d%d", &op, &x);
106         if (op == 1){
107             scanf("%lld", &cover);
108             update(1, 1, n, e[x].l, e[x].r, cover);
109         }
110 
111         else{
112             ans = 0;
113             solve(1, 1, n, e[x].l, e[x].r);
114             int ret = 0;
115             for (; ans; ans -= ans & -ans) ++ret;
116             printf("%d\n", ret);
117         }
118     }
119 
120     return 0;
121 }

 

posted @ 2017-05-02 21:37  cxhscst2  阅读(225)  评论(0)  编辑  收藏  举报