JZOJ 6904. 【2020.11.28提高组模拟】T3 树上询问(query)

题目

你有一棵 \(n\) 节点的树 ,回答 \(m\) 个询问,每次询问给你两个整数 \(l,r\) ,问存在多少个整数 \(k\) 使得从 \(l\) 沿着 \(l \to r\) 的简单路径走 \(k\) 步恰好到达 \(k\)

分析

考虑离线后按链记贡献
\(l\)\(lca(l,r)\) 这段链上,可以计入贡献的点 \(x\) 满足 \(dep[l]-x=dep[x]\),称为一类贡献
\(dep[x]+x=dep[l]\), 因为已知 \(dep[l]\),所以直接开桶计算
\(lca(l,r)\)\(r\) 这段链上,可以计入贡献的点 \(x\) 满足 \(dep[lca]+(x-dep[l]-dep[lca])=dep[x]\),称为二类贡献
\(dep[x]-x=2\times dep[lca]-dep[l]\),同样可以直接开另一个桶计算
因为 \(dfs\) 下来时桶记录的是根到当前点的信息,所以算贡献的时候要减去 \(lca\) 处的假贡献
\(lca\) 也可能成为需要贡献,所以算二类贡献的时候减去 \(father_{lca}\) 处的贡献
具体细节体现在代码

\(Code\)

#include<cstdio>
#include<vector>
using namespace std;

const int N = 3e5 + 5;
int n, m, dep[N], d[2][2*N], fa[N], da[N], vis[N], l[N], r[N], lca[N], ans[N];
vector<int> e[N];
struct node1{int x, id;};
vector<node1> q1[N];
struct node2{int cs, ty, f, id;};
vector<node2> q2[N];

int find(int x){return fa[x] == x ? x : fa[x] = find(fa[x]);}
void dfs(int x, int dad)
{
	da[x] = dad, dep[x] = dep[dad] + 1;
	for(register int i = 0; i < e[x].size(); i++)
	{
		if (e[x][i] == dad) continue;
		dfs(e[x][i], x);
	}
}
void dfs1(int x, int dad)
{
	vis[x] = 1;
	for(register int i = 0; i < e[x].size(); i++)
	{
		if (e[x][i] == dad) continue;
		dfs1(e[x][i], x), fa[e[x][i]] = x;
	}
	for(register int i = 0; i < q1[x].size(); i++)
		if (vis[q1[x][i].x]) lca[q1[x][i].id] = find(q1[x][i].x);
}
void dfs2(int x, int dad)
{
	++d[0][dep[x] + x], ++d[1][dep[x] - x + n];
	for(register int i = 0; i < q2[x].size(); i++)
		ans[q2[x][i].id] += q2[x][i].f * d[q2[x][i].ty][q2[x][i].cs];
	for(register int i = 0; i < e[x].size(); i++)
	{
		if (e[x][i] == dad) continue;
		dfs2(e[x][i], x);
	}
	--d[0][dep[x] + x], --d[1][dep[x] - x + n];
}

int main()
{
	freopen("query.in" , "r" , stdin);
	freopen("query.out" , "w" , stdout);
	scanf("%d%d" , &n , &m);
	int x , y;
	for(register int i = 1; i < n; i++)
	{
		scanf("%d%d" , &x , &y);
		e[x].push_back(y), e[y].push_back(x);
	}
	for(register int i = 1; i <= m; i++)
	{
		scanf("%d%d" , &l[i], &r[i]);
		q1[l[i]].push_back(node1{r[i], i});
		q1[r[i]].push_back(node1{l[i], i});
	}
	for(register int i = 1; i <= n; i++) fa[i] = i;
	dfs(1, 0), dfs1(1, 0);
	for(register int i = 1; i <= m; i++)
	{
		q2[l[i]].push_back(node2{dep[l[i]], 0, 1, i}); 
		q2[lca[i]].push_back(node2{dep[l[i]], 0, -1, i});
		q2[r[i]].push_back(node2{2*dep[lca[i]]-dep[l[i]]+n, 1, 1, i});
		if (lca[i] > 1) 
			q2[da[lca[i]]].push_back(node2{2*dep[lca[i]]-dep[l[i]]+n, 1, -1, i});
	}
	dfs2(1, 0);
	for(register int i = 1; i <= m; i++) printf("%d\n" , ans[i]);
}
posted @ 2020-11-29 18:14  leiyuanze  阅读(184)  评论(0编辑  收藏  举报