luogu P5311 [Ynoi2011] 成都七中

https://www.luogu.com.cn/problem/P5311

首先要注意到一个很重要的性质
x x x在原图上的联通快,在点分树上也是一个联通快
挺显然的吧
然后把询问挂到,点分树x能到达的最浅的祖先上
对于每一个点
问题转换为求有多少个点到根的颜色区间[l,r]在询问的区间范围内

乍一看好像是二维偏序,实际上可以离线的话有更好的做法

考虑将询问区间和颜色区间放在一起,按照 l l l从大到小排序
在开个桶,记录每个颜色最小的右端点,用树状数组把把对应位置设成1,维护前缀和即可
被教育了QWQ
code:

#include<bits/stdc++.h>
#define N 500050
using namespace std;
struct edge {
    int v, nxt;
} e[N << 1];
int p[N], eid;
void init() {
    memset(p, - 1, sizeof p);
    eid = 0;
}
void insert(int u, int v) {
    e[eid].v = v;
    e[eid].nxt = p[u];
    p[u] = eid ++;
}
struct Q {
    int l, r, col, o;
};
vector<Q> q[N], b[N];
int cmp(Q x, Q y) {
    if(x.l != y.l) return x.l > y.l;
    return x.o < y.o;
}
int col[N];
int size[N], msize[N], vis[N], FA[N];
int sz, mx, rt;
void dfs(int u, int ff) {
    size[u] = 1; msize[u] = 0;
    for(int i = p[u]; i + 1; i = e[i].nxt) {
        int v = e[i].v;
        if(v == ff || vis[v]) continue;
        dfs(v, u); size[u] += size[v];
        msize[u] = max(msize[u], size[v]);
    }
    msize[u] = max(msize[u], sz - size[u]);
    if(msize[u] < mx) mx = msize[u], rt = u;
}
void dfss(int u, int ff, int mi = 1e9, int mx = -1e9) { //printf("%d %d %d %d\n", u, ff, mi, mx);
    mi = min(mi, u), mx = max(mx, u); q[rt].push_back((Q){mi, mx, col[u], 0});
    b[u].push_back((Q){mi, mx, rt, 0});
    for(int i = p[u]; i + 1; i = e[i].nxt) {
        int v = e[i].v;
        if(v == ff || vis[v]) continue;
        dfss(v, u, mi, mx);
    }
}
void solve(int u, int ff, int n) {// printf("%d %d %d\n", u, ff, n);
    mx = sz = n, rt = 0;
    dfs(u, u); u = rt; dfss(u, u);
    FA[u] = ff; size[u] = sz;
    vis[u] = 1;
    for(int i = p[u]; i + 1; i = e[i].nxt) {
        int v = e[i].v;
        if(vis[v]) continue;
        solve(v, u, size[v]);
    }
}
int n, m, ANS[N];
#define lowbit(x) (x & -x)
int t[N], ha[N];
void update(int x, int y) {
    for(; x <= n; x += lowbit(x)) t[x] += y;
}
int query(int x) {
    int ret = 0;
    for(; x; x -= lowbit(x)) ret += t[x];
    return ret;
}
void clear(int x) {
    for(; x <= n; x += lowbit(x)) t[x] = 0;
}
int main() {
    init();
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i ++) scanf("%d", &col[i]);
    for(int i = 1; i < n; i ++) {
        int u, v;
        scanf("%d%d", &u, &v);
        insert(u, v), insert(v, u);
    }
    solve(1, 0, n);
    for(int i = 1; i <= m; i ++) {
        int l, r, x;
        scanf("%d%d%d", &l, &r, &x);
        for(int j = 0; j < b[x].size(); j ++) {
            if(b[x][j].l >= l && b[x][j].r <= r) {
                q[b[x][j].col].push_back((Q){l, r, i, 1});
            //    printf("  %d %d %d %d\n", l, r, i, x);
                break;
            }
        }
    }
    memset(ha, 0x3f, sizeof ha);
    for(int i = 1; i <= n; i ++) {
        sort(q[i].begin(), q[i].end(), cmp);
        for(int j = 0; j < q[i].size(); j ++) {
            Q x = q[i][j];
            if(x.o) ANS[x.col] = query(x.r);
            else {
                if(x.r < ha[x.col]) update(ha[x.col], -1), ha[x.col] = x.r, update(ha[x.col], 1);
            }
        }
        for(int j = 0; j < q[i].size(); j ++) {
            Q x = q[i][j]; ha[x.col] = n + 1;
            if(!x.o) clear(x.r);
        }
    }
    for(int i = 1; i <= m; i ++) printf("%d\n", ANS[i]);
    return 0;
}
posted @ 2021-08-06 13:21  lahlah  阅读(38)  评论(0编辑  收藏  举报