洛谷P2680 运输计划(LCA + 二分 + 树上边差分)
洛谷P2680 运输计划
现在有一棵树,每条树边上都有正权值。接下来,有 m 个询问,每次询问给出两个结点,这两个结点之间有一条路径。现在你可以任选一条树边,将其边权置为0,请输出询问中的路径的最大值最小值是多少。
思路:
看到使最大值最小,就在向二分的想法贴。我们可以预处理出 m 个询问的每个询问的路径长度。想了想二分答案的话,那check函数就只能是\(O(m)\)吧。二分一下这个最大路径的最小值,遍历一遍这些预处理好的边长度,若长度大于二分的答案,那就是需要被删边的路径,我们算一下 m 里有 cnt 个路径是需要被删边的。对边进行差分(利用树的每个结点至多有一个父亲的性质,把父亲的边权归到自己作为点权),差分后求一个前缀和,枚举所有差分数组值等于 cnt 的边,取这些边的最大权值,若最大的查询路径长度减去这个最大边权后小于等于二分答案,那就是合法的。
代码
本题的二分做法会卡常,请使用链式前向星。
#include <bits/stdc++.h>
using namespace std;
const int N = 300010;
int n, m;
vector<array<int, 4> > vc; //x, y, lca, dis(x, y)
int dis[N], dep[N], val[N];
int anc[N][18];
int d[N];
int max_dis;
int nex[N << 1], e[N << 1], wei[N << 1], h[N], tot = 2;
void add(int a, int b, int c)
{
e[tot] = b;
wei[tot] = c;
nex[tot] = h[a];
h[a] = tot ++;1
}
void dfs2(int u, int fa)
{
for(int i = h[u]; i; i = nex[i])
{
int v = e[i];
if(v == fa) continue;
dfs2(v, u);
d[u] += d[v];
}
}
bool check(int mid)
{
memset(d, 0, sizeof d);
int res = 0;
int cnt = 0;
for(auto [x, y, lca, dist] : vc)
{
if(dist > mid)
{
d[x] ++, d[y] ++, d[lca] -= 2;
cnt ++;
}
}
dfs2(1, 0);
for(int i = 2; i <= n; i ++)
{
if(d[i] == cnt)
res = max(res, val[i]);
}
return max_dis - res <= mid;
}
void dfs1(int u, int fa)
{
dep[u] = dep[fa] + 1;
anc[u][0] = fa;
for(int i = 1; (1 << i) <= dep[u]; i ++)
anc[u][i] = anc[anc[u][i - 1]][i - 1];
for(int i = h[u]; i; i = nex[i])
{
int v = e[i], w = wei[i];
if(v == fa) continue;
val[v] = w;
dis[v] = dis[u] + w;
dfs1(v, u);
}
}
int LCA(int x, int y)
{
if(dep[x] < dep[y])
swap(x, y);
for(int i = 17; i >= 0; i --)
if(dep[x] - (1 << i) >= dep[y])
x = anc[x][i];
if(x == y)
return x;
for(int i = 17; i >= 0; i --)
if(anc[x][i] != anc[y][i])
x = anc[x][i], y = anc[y][i];
return anc[x][0];
}
void init()
{
for(auto &[x, y, lca, dist] : vc)
{
lca = LCA(x, y);
dist = dis[x] - 2 * dis[lca] + dis[y];
max_dis = max(max_dis, dist);
}
}
signed main()
{
read(n); read(m);
vc.resize(m);
for(int i = 1; i < n; i ++)
{
int x, y, c;
read(x); read(y); read(c);
add(x, y, c);
add(y, x, c);
}
dfs1(1, 0);
for(int i = 0; i < m; i ++)
{
int x, y;
read(x); read(y);
vc[i] = {x, y, 0, 0};
}
init();
int l = 0, r = max_dis;
while(l < r)
{
int mid = (l + r) >> 1;
if(check(mid))
r = mid;
else
l = mid + 1;
}
printf("%d\n", r);
}