Tree and Queries — DSU on tree

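// Problem: given a tree rooted at node 1 where every node has a colour, answer queries (v, k):
//          how many colours occur on at least k nodes of the subtree rooted at v.
// Idea: this is a textbook DSU-on-tree (small-to-large) problem; the costs of answering and of
//       clearing must be watched. The first version (commented out below) copied and sorted the
//       whole colour-count array at every node before binary-searching it, i.e. O(C log C) per
//       node and O(n * C log C) overall, which times out. Its memset-based clearing adds another
//       O(N) per call (cnt at each light-subtree root, temp at every node) with a large constant.
//       Hence the optimised version: keep stk[j] = number of colours with count >= j, so every
//       query is answered in O(1) and clearing only touches the nodes that were added.
//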
/*#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
int n, m, col[N], ans[N];
vector<int> mp[N];
vector<pair<int, int>> qus[N];//第一维k,第二维问题编号
int sz[N], fa[N], son[N], color, tim;
void dfs1(int x, int f) {
    sz[x] = 1;
    fa[x] = f;
    for (auto y : mp[x]) {
        if (y == f) continue;
        dfs1(y, x);
        sz[x] += sz[y];
        if (sz[y] > sz[son[x]]) son[x] = y;
    }
}
int cnt[N];
void dfs2(int x) {
    cnt[col[x]]++;
    for (auto y : mp[x]) {
        if (y == fa[x]) continue;
        dfs2(y);
    }
}
int temp[N];
void solve(int x, bool kep) {
    for (auto y : mp[x]) {
        if (y == fa[x] || y == son[x]) continue;
        solve(y, 0);
    }
    if (son[x]) solve(son[x], 1);
    
    for (auto y : mp[x]) {
        if (y == fa[x] || y == son[x]) continue;
        dfs2(y);
    }
    cnt[col[x]]++;

    for (int i = 1; i <= color; i++) temp[i] = cnt[i];

    sort(temp, temp + 1 + color);
    for (auto y : qus[x]) {
        int len = lower_bound(temp, temp + color + 1, y.first) - temp;
        ans[y.second] = color - len + 1;
    }

    if (!kep) memset(cnt, 0, sizeof(cnt));
    memset(temp, 0, sizeof(temp));
}
int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> col[i]; 
        color = max(color, col[i]);
    }
    for (int i = 1; i <= n - 1; i++) {
        int a, b; cin >> a >> b;
        mp[a].push_back(b);
        mp[b].push_back(a);
    }
    for (int i = 1; i <= m; i++) {
        int a, b; cin >> a >> b;
        qus[a].push_back({ b,i });
    }
    dfs1(1, 0);
    solve(1, 0);
    for (int i = 1; i <= m; i++) cout << ans[i] << endl;
    return 0;
}*/


#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;            // capacity shared by every per-node / per-colour / per-count array

int n;                             // number of nodes
int m;                             // number of queries
int col[N];                        // col[v]: colour of node v (nodes are 1-indexed)
int ans[N];                        // ans[i]: answer to the i-th query

// Adjacency lists kept as a chained forward-star edge pool; main() stores each
// undirected edge as two directed entries.
struct edge {
    int to;                        // endpoint this directed edge leads to
    int nt;                        // pool index of the next edge with the same tail; 0 ends the list
} mp[2 * N];

vector<pair<int, int>> qus[N];     // qus[v]: queries on v's subtree, stored as {k, query index}

int sz[N];                         // sz[v]: number of nodes in the subtree rooted at v
int fa[N];                         // fa[v]: parent of v (0 for the root)
int son[N];                        // son[v]: heavy child of v (0 when v has no children)
int head[N];                       // head[v]: pool index of v's first outgoing edge (0 = none)
int num;                           // edge slots used so far; slot 0 is never used (sentinel)
// Append the directed edge a -> b to the edge pool and make it the first edge in a's list.
void add(int a, int b) {
    const int id = ++num;          // pre-increment keeps slot 0 free so 0 can mean "no edge"
    mp[id].to = b;
    mp[id].nt = head[a];
    head[a] = id;
}

// First pass from x (whose parent is f): record fa[x], the subtree size sz[x], and the heavy
// child son[x] = the first child met with a strictly larger subtree than the current choice.
// Relies on son[x] starting at 0 and sz[0] == 0 (zero-initialised globals, single call).
void dfs1(int x, int f) {
    fa[x] = f;
    sz[x] = 1;
    for (int e = head[x]; e != 0; e = mp[e].nt) {
        const int child = mp[e].to;
        if (child == f) {
            continue;
        }
        dfs1(child, x);
        sz[x] += sz[child];
        if (sz[son[x]] < sz[child]) {
            son[x] = child;
        }
    }
}
int cnt[N];  // cnt[c]: number of currently counted nodes that have colour c
int stk[N];  // stk[j]: number of colours whose cnt is currently >= j
// Add (fg == true) or remove (fg == false) every node of x's subtree from the counters,
// keeping stk[j] equal to the number of colours whose count is at least j.
void dfs2(int x, bool fg) {
    int &c = cnt[col[x]];
    if (fg) {
        c += 1;                    // this colour now reaches count c ...
        stk[c] += 1;               // ... so one more colour has count >= c
    } else {
        stk[c] -= 1;               // this colour no longer has count >= c ...
        c -= 1;                    // ... because its count drops by one
    }
    for (int e = head[x]; e != 0; e = mp[e].nt) {
        const int v = mp[e].to;
        if (v != fa[x]) {
            dfs2(v, fg);
        }
    }
}
// DSU on tree (small-to-large merging) over the subtree of x.
//   x   : current node
//   kep : whether x's subtree must remain counted in cnt/stk on return
//         (solve passes 1 only for a heavy child and 0 for light children)
// Order: solve light children and discard them, solve the heavy child and keep it,
// re-add the light subtrees and x, answer x's queries, then clear everything if !kep.
void solve(int x, bool kep) {
    // 1) Light children first; each one removes its own contribution before returning.
    for (auto w = head[x]; w; w = mp[w].nt) {
        int y = mp[w].to;
        if (y == fa[x] || y == son[x]) continue;
        solve(y, 0);
    }
    // 2) Heavy child last, keeping its counts so they need not be rebuilt.
    if (son[x]) solve(son[x], 1);

    // 3) Re-add every light subtree, then x itself, on top of the heavy child's counts.
    for (auto w = head[x]; w; w = mp[w].nt) {
        int y = mp[w].to;
        if (y == fa[x] || y == son[x]) continue;
        dfs2(y, 1);
    }
    cnt[col[x]]++;
    stk[cnt[col[x]]]++;

    // 4) stk[k] is now the number of colours occurring at least k times in x's subtree.
    //    Guard the index instead of reading out of bounds. For k >= stk_len no colour can
    //    reach that count (a subtree has at most n < N nodes, as the array sizes already
    //    require), so 0 is exact. A negative k is invalid input and also yields 0.
    const std::size_t stk_len = sizeof(stk) / sizeof(stk[0]);
    for (const auto &q : qus[x]) {
        const int k = q.first;
        ans[q.second] = (k >= 0 && static_cast<std::size_t>(k) < stk_len) ? stk[k] : 0;
    }

    // 5) Not kept by the caller: remove the whole subtree, heavy child included.
    if (!kep) {
        for (auto w = head[x]; w; w = mp[w].nt) {
            int y = mp[w].to;
            if (y == fa[x]) continue;
            dfs2(y, 0);
        }
        stk[cnt[col[x]]]--;
        cnt[col[x]]--;
    }
}
// Reads "n m", then n colours, then n-1 undirected edges "a b", then m queries "v k".
// Roots the tree at node 1, runs DSU on tree, and prints one answer per query line.
int main() {
    // The input holds O(n + m) integers and the output has m lines. Unsync/untie the streams,
    // and write '\n' instead of endl, which would flush after every single answer.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> col[i];
    for (int i = 1; i <= n - 1; i++) {
        int a, b; cin >> a >> b;
        add(a, b);   // store both directions of the undirected edge
        add(b, a);
    }
    for (int i = 1; i <= m; i++) {
        int a, b; cin >> a >> b;
        qus[a].push_back({ b,i });   // {k, query index} attached to the query's root node
    }
    dfs1(1, 0);
    solve(1, 0);
    for (int i = 1; i <= m; i++) cout << ans[i] << '\n';
    return 0;
}

 

posted @ 2023-01-09 15:43  Aacaod  阅读(19)  评论(0)  编辑  收藏  举报