Loading

【题解】[AHOI2022] 钥匙

提供一个简单树剖/差分的做法。

我们先考虑 \(A\) 性质怎么做,令每种颜色的钥匙和箱子分别为 \((x,y)\),那么会对起点在 \(x\) 一端,终点在 \(y\) 一端的询问产生贡献。所以我们对询问离线,问题转化为矩阵加单点查询,直接扫描线即可。

没有特殊性质,但我们延续上面的思路,考虑对同颜色的钥匙和箱子 \((x,y)\) 计算贡献。由于钥匙 \(\le 5\),所以所有 \((x,y)\) 最多是 \(5n\)

我们依次考虑每种颜色。但并不是所有的 \((x,y)\) 都满足条件,观察一下不难发现,如果我们将钥匙记为 \(+1\),箱子记为 \(-1\),其余记为 \(0\),那么树链 \(x\to y\) 的和为 \(0\) 是产生贡献的必要条件。

但是这样连样例都过不去,答案会偏大,显然有贡献重复统计了。比如 \(1,-1,1,-1\),其中 \((1,2)\)\((1,4)\) 被计算了两次。

如果数对 \((x,y)\) 在路径 \(x\to y\) 上有 \(z\) 使得 \((x,z)\) 产生贡献,那么 \(y\) 就不产生贡献。那么我们固定 \(x\),将所有可能的 \(y\) 按到 \(x\) 的距离排序,依次考虑,如果产生贡献就在 \(y\) 单点加,每次统计 \((x,y)\) 时先判断一下是否路径和为 \(0\)

可以直接用两个树链剖分维护,时间复杂度 \(\mathcal{O}(5n\log^2 n + m\log m)\),由于树剖和树状数组的常数都很小,可以无压力通过。

事实上,我们维护的是单点加,树链查询,直接树上差分树状数组维护,时间复杂度 \(\mathcal{O}((n+m)\log n)\)

树剖写的习惯,下面是树剖代码。

#define N 500005
int n, m, c[N], d[N], dfn[N], fa[N], top[N], sz[N], son[N], mat[N], ds[N], idx, w[N];
inline void add(int x,int y){for(; x <= n; x += x & -x)c[x] += y;}
inline int ask(int x){int sum = 0; for(; x; x -= x & -x)sum += c[x]; return sum;}
inline void Add(int x,int y){for(; x <= n; x += x & -x)w[x] += y;}
inline int Ask(int x){int sum = 0; for(; x; x -= x & -x)sum += w[x]; return sum;}
inline int get(int l,int r){return ask(r) - ask(l - 1);}
inline int Get(int l,int r){return Ask(r) - Ask(l - 1);}
vector<int>a[N], b[N], e[N]; vector<Pr> u[N], v[N];
void dfs(int x,int f){
	d[x] = d[fa[x] = f] + (sz[x] = 1);
	go(y, e[x])if(y != f){
		dfs(y, x), sz[x] += sz[y];
		if(sz[y] > sz[son[x]])son[x] = y;
	}
}
void dfs2(int x,int tp){
	top[x] = tp, mat[dfn[x] = ++idx] = x;
	if(!son[x])return;
	dfs2(son[x], tp);
	go(y, e[x])if(y != fa[x] && y != son[x])dfs2(y, y);
}
inline int lca(int x,int y){
	while(top[x] != top[y]){
		if(d[top[x]] < d[top[y]])swap(x, y);
		x = fa[top[x]];
	}
	return d[x] < d[y] ? x : y;
}
inline int dis(int x,int y){return d[x] + d[y] - d[lca(x, y)] * 2;}
int query(int x,int y){
	int sum = 0;
	while(top[x] != top[y]){
		if(d[top[x]] < d[top[y]])swap(x, y);
		sum += get(dfn[top[x]], dfn[x]), x = fa[top[x]];
	}
	if(d[x] > d[y])swap(x, y);
	sum += get(dfn[x], dfn[y]);
	return sum;
}
int Query(int x,int y){
	int sum = 0;
	while(top[x] != top[y]){
		if(d[top[x]] < d[top[y]])swap(x, y);
		sum += Get(dfn[top[x]], dfn[x]), x = fa[top[x]];
	}
	if(d[x] > d[y])swap(x, y);
	sum += Get(dfn[x], dfn[y]);
	return sum;
}
int calc(int x,int k){
	while(d[top[x]] > k)x = fa[top[x]];
	return mat[dfn[x] - d[x] + k];
}
#define ck(x, y) (dfn[x] <= dfn[y] && dfn[x] + sz[x] > dfn[y])
inline void ins(int l,int r,int L,int R){
	if(l > r || L > R)return;
	u[l].pb(mp(L, R)), v[r + 1].pb(mp(L, R));
}
void ins(int x,int y){  //  x  to  y
	if(ck(x, y)){
		int z = calc(y, d[x] + 1);
		ins(1, dfn[z] - 1, dfn[y], dfn[y] + sz[y] - 1);
		ins(dfn[z] + sz[z], n, dfn[y], dfn[y] + sz[y] - 1);
	}
	else if(ck(y, x)){
		int z = calc(x, d[y] + 1);
		ins(dfn[x], dfn[x] + sz[x] - 1, 1, dfn[z] - 1);
		ins(dfn[x], dfn[x] + sz[x] - 1, dfn[z] + sz[z], n);
	}
	else ins(dfn[x], dfn[x] + sz[x] - 1, dfn[y], dfn[y] + sz[y] - 1);
}
struct node{
	int l, r, id;
	bool operator<(const node o)const{return l < o.l;}
}q[N << 1];
int ed[N << 1];
int main() {
	read(n, m);
	rp(i, n){
		int x, y; read(x, y);
		if(1 == x)a[y].pb(i);
		else b[y].pb(i);
	}
	rp(i, n - 1){
		int x, y;
		read(x, y);
		e[x].pb(y), e[y].pb(x);
	}
	dfs(1, 0), dfs2(1, 1);
	rp(i, n)if(si(a[i]) && si(b[i])){
		go(x, a[i])add(dfn[x], 1);
		go(x, b[i])add(dfn[x], -1);
		go(x, a[i]){
			go(y, b[i])ds[y] = dis(x, y);
			sort(b[i].begin(), b[i].end(), [&](int x,int y){return ds[x] < ds[y];});
			vector<int>dl;
			go(y, b[i]){
				if(0 == query(x, y) && !Query(x, y)){
					Add(dfn[y], 1), dl.pb(y), ins(x, y);
				}
			}
			go(y, dl)Add(dfn[y], -1);
		}
		go(x, a[i])add(dfn[x], -1);
		go(x, b[i])add(dfn[x], 1);
	}
	memset(c, 0, sizeof(c));
	rp(i, m)read(q[i].l, q[i].r), q[i].id = i, q[i].l = dfn[q[i].l], q[i].r = dfn[q[i].r];
	sort(q + 1, q + m + 1);
	int j = 1;
	rp(i, n){
		go(x, u[i])add(x.fi, 1), add(x.se + 1, -1);
		go(x, v[i])add(x.fi, -1), add(x.se + 1, 1);
		while(j <= m && q[j].l == i)ed[q[j].id] = ask(q[j].r), j++;
	}
	rp(i, m)printf("%d\n", ed[i]);
	return 0;
}
posted @ 2022-05-10 18:45  7KByte  阅读(96)  评论(0编辑  收藏  举报