JZOJ 6904. 【2020.11.28提高组模拟】T3 树上询问(query)
题目
你有一棵 \(n\) 节点的树 ,回答 \(m\) 个询问,每次询问给你两个整数 \(l,r\) ,问存在多少个整数 \(k\) 使得从 \(l\) 沿着 \(l \to r\) 的简单路径走 \(k\) 步恰好到达 \(k\) 。
分析
考虑离线后按链记贡献
从 \(l\) 到 \(lca(l,r)\) 这段链上,可以计入贡献的点 \(x\) 满足 \(dep[l]-x=dep[x]\),称为一类贡献
即 \(dep[x]+x=dep[l]\), 因为已知 \(dep[l]\),所以直接开桶计算
从 \(lca(l,r)\) 到 \(r\) 这段链上,可以计入贡献的点 \(x\) 满足 \(dep[lca]+(x-dep[l]-dep[lca])=dep[x]\),称为二类贡献
即 \(dep[x]-x=2\times dep[lca]-dep[l]\),同样可以直接开另一个桶计算
因为 \(dfs\) 下来时桶记录的是根到当前点的信息,所以算贡献的时候要减去 \(lca\) 处的假贡献
\(lca\) 也可能成为需要贡献,所以算二类贡献的时候减去 \(father_{lca}\) 处的贡献
具体细节体现在代码
\(Code\)
#include<cstdio>
#include<vector>
using namespace std;
const int N = 3e5 + 5;
int n, m, dep[N], d[2][2*N], fa[N], da[N], vis[N], l[N], r[N], lca[N], ans[N];
vector<int> e[N];
struct node1{int x, id;};
vector<node1> q1[N];
struct node2{int cs, ty, f, id;};
vector<node2> q2[N];
int find(int x){return fa[x] == x ? x : fa[x] = find(fa[x]);}
void dfs(int x, int dad)
{
da[x] = dad, dep[x] = dep[dad] + 1;
for(register int i = 0; i < e[x].size(); i++)
{
if (e[x][i] == dad) continue;
dfs(e[x][i], x);
}
}
void dfs1(int x, int dad)
{
vis[x] = 1;
for(register int i = 0; i < e[x].size(); i++)
{
if (e[x][i] == dad) continue;
dfs1(e[x][i], x), fa[e[x][i]] = x;
}
for(register int i = 0; i < q1[x].size(); i++)
if (vis[q1[x][i].x]) lca[q1[x][i].id] = find(q1[x][i].x);
}
void dfs2(int x, int dad)
{
++d[0][dep[x] + x], ++d[1][dep[x] - x + n];
for(register int i = 0; i < q2[x].size(); i++)
ans[q2[x][i].id] += q2[x][i].f * d[q2[x][i].ty][q2[x][i].cs];
for(register int i = 0; i < e[x].size(); i++)
{
if (e[x][i] == dad) continue;
dfs2(e[x][i], x);
}
--d[0][dep[x] + x], --d[1][dep[x] - x + n];
}
int main()
{
freopen("query.in" , "r" , stdin);
freopen("query.out" , "w" , stdout);
scanf("%d%d" , &n , &m);
int x , y;
for(register int i = 1; i < n; i++)
{
scanf("%d%d" , &x , &y);
e[x].push_back(y), e[y].push_back(x);
}
for(register int i = 1; i <= m; i++)
{
scanf("%d%d" , &l[i], &r[i]);
q1[l[i]].push_back(node1{r[i], i});
q1[r[i]].push_back(node1{l[i], i});
}
for(register int i = 1; i <= n; i++) fa[i] = i;
dfs(1, 0), dfs1(1, 0);
for(register int i = 1; i <= m; i++)
{
q2[l[i]].push_back(node2{dep[l[i]], 0, 1, i});
q2[lca[i]].push_back(node2{dep[l[i]], 0, -1, i});
q2[r[i]].push_back(node2{2*dep[lca[i]]-dep[l[i]]+n, 1, 1, i});
if (lca[i] > 1)
q2[da[lca[i]]].push_back(node2{2*dep[lca[i]]-dep[l[i]]+n, 1, -1, i});
}
dfs2(1, 0);
for(register int i = 1; i <= m; i++) printf("%d\n" , ans[i]);
}