luogu P5311 [Ynoi2011] 成都七中
https://www.luogu.com.cn/problem/P5311
首先要注意到一个很重要的性质
x
x
x在原图上的联通快,在点分树上也是一个联通快
挺显然的吧
然后把询问挂到,点分树x能到达的最浅的祖先上
对于每一个点
问题转换为求有多少个点到根的颜色区间[l,r]在询问的区间范围内
乍一看好像是二维偏序,实际上可以离线的话有更好的做法
考虑将询问区间和颜色区间放在一起,按照
l
l
l从大到小排序
在开个桶,记录每个颜色最小的右端点,用树状数组把把对应位置设成1,维护前缀和即可
被教育了QWQ
code:
#include<bits/stdc++.h>
#define N 500050
using namespace std;
struct edge {
int v, nxt;
} e[N << 1];
int p[N], eid;
void init() {
memset(p, - 1, sizeof p);
eid = 0;
}
void insert(int u, int v) {
e[eid].v = v;
e[eid].nxt = p[u];
p[u] = eid ++;
}
struct Q {
int l, r, col, o;
};
vector<Q> q[N], b[N];
int cmp(Q x, Q y) {
if(x.l != y.l) return x.l > y.l;
return x.o < y.o;
}
int col[N];
int size[N], msize[N], vis[N], FA[N];
int sz, mx, rt;
void dfs(int u, int ff) {
size[u] = 1; msize[u] = 0;
for(int i = p[u]; i + 1; i = e[i].nxt) {
int v = e[i].v;
if(v == ff || vis[v]) continue;
dfs(v, u); size[u] += size[v];
msize[u] = max(msize[u], size[v]);
}
msize[u] = max(msize[u], sz - size[u]);
if(msize[u] < mx) mx = msize[u], rt = u;
}
void dfss(int u, int ff, int mi = 1e9, int mx = -1e9) { //printf("%d %d %d %d\n", u, ff, mi, mx);
mi = min(mi, u), mx = max(mx, u); q[rt].push_back((Q){mi, mx, col[u], 0});
b[u].push_back((Q){mi, mx, rt, 0});
for(int i = p[u]; i + 1; i = e[i].nxt) {
int v = e[i].v;
if(v == ff || vis[v]) continue;
dfss(v, u, mi, mx);
}
}
void solve(int u, int ff, int n) {// printf("%d %d %d\n", u, ff, n);
mx = sz = n, rt = 0;
dfs(u, u); u = rt; dfss(u, u);
FA[u] = ff; size[u] = sz;
vis[u] = 1;
for(int i = p[u]; i + 1; i = e[i].nxt) {
int v = e[i].v;
if(vis[v]) continue;
solve(v, u, size[v]);
}
}
int n, m, ANS[N];
#define lowbit(x) (x & -x)
int t[N], ha[N];
void update(int x, int y) {
for(; x <= n; x += lowbit(x)) t[x] += y;
}
int query(int x) {
int ret = 0;
for(; x; x -= lowbit(x)) ret += t[x];
return ret;
}
void clear(int x) {
for(; x <= n; x += lowbit(x)) t[x] = 0;
}
int main() {
init();
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i ++) scanf("%d", &col[i]);
for(int i = 1; i < n; i ++) {
int u, v;
scanf("%d%d", &u, &v);
insert(u, v), insert(v, u);
}
solve(1, 0, n);
for(int i = 1; i <= m; i ++) {
int l, r, x;
scanf("%d%d%d", &l, &r, &x);
for(int j = 0; j < b[x].size(); j ++) {
if(b[x][j].l >= l && b[x][j].r <= r) {
q[b[x][j].col].push_back((Q){l, r, i, 1});
// printf(" %d %d %d %d\n", l, r, i, x);
break;
}
}
}
memset(ha, 0x3f, sizeof ha);
for(int i = 1; i <= n; i ++) {
sort(q[i].begin(), q[i].end(), cmp);
for(int j = 0; j < q[i].size(); j ++) {
Q x = q[i][j];
if(x.o) ANS[x.col] = query(x.r);
else {
if(x.r < ha[x.col]) update(ha[x.col], -1), ha[x.col] = x.r, update(ha[x.col], 1);
}
}
for(int j = 0; j < q[i].size(); j ++) {
Q x = q[i][j]; ha[x.col] = n + 1;
if(!x.o) clear(x.r);
}
}
for(int i = 1; i <= m; i ++) printf("%d\n", ANS[i]);
return 0;
}