[luogu2680] 运输计划
题面
很明显, 由于是求最长路的最小值, 我们可以使用二分求解. 我们二分一个长度\(mid\), 将所有使得\(dis(u, v)\)大于\(mid\)的点对\((u, v)\)找出, 设总共有\(m\)条这样的边, 那么我们需要改变的改变边一定是被这\(m\)条边都经过的边, 所以我们只需要找到满足这个要求的长度最大的改变边使得这\(m\)条边中最长的一条减去这一条改变边小于等于\(mid\) 就可以了, 如果没有这样的边就说明\(mid\)小了. 那要怎么找这条改变边呢??? 树上差分就可以了啊, 将边权下放为点权, 让\(cnt[u]++\), \(cnt[v]++\), \(cnt[lca(u, v)] -= 2\)就可以了, 至于某个点被经过的次数就将他子树的累加起来即可.
具体代码
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#define N 300005
using namespace std;
int n, m, head[N], tot, fa[N], cost[N], road[N], cnt, num[N], vis[N], l, r, mid, ans;
struct node
{
int to, cost, next;
} edge[N << 1];
struct Ask
{
int u, v, lca, dis;
} ask[N];
struct node2
{
int v, id;
};
vector<node2> LCA[N];
inline int read()
{
int x = 0, w = 1;
char c = getchar();
while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
return x * w;
}
inline void add(int u, int v, int w) { edge[++tot].to = v; edge[tot].cost = w; edge[tot].next = head[u]; head[u] = tot; }
int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }
void tarjan(int u, int ff)
{
vis[u] = 1;
for(int i = head[u]; i; i = edge[i].next)
{
int v = edge[i].to; if(v == ff) continue;
cost[v] = edge[i].cost; road[v] = road[u] + edge[i].cost;
tarjan(v, u); fa[v] = u;
}
for(int i = 0; i < (int) LCA[u].size(); i++)
if(vis[LCA[u][i].v] && !ask[LCA[u][i].id].lca)
ask[LCA[u][i].id].lca = find(LCA[u][i].v);
}
void dfs(int u, int fa)
{
for(int i = head[u]; i; i = edge[i].next)
{
int v = edge[i].to; if(v == fa) continue;
dfs(v, u); num[u] += num[v];
}
}
bool check(int mid)
{
memset(num, 0, sizeof(num));
int mx = -1, sum = 0;
for(int i = 1; i <= m; i++)
if(ask[i].dis > mid) { sum++; mx = max(mx, ask[i].dis); num[ask[i].u]++; num[ask[i].v]++; num[ask[i].lca] -= 2; }
dfs(1, 0);
for(int i = 1; i <= n; i++)
if(num[i] >= sum && mx - cost[i] <= mid) return 1;
return 0;
}
int main()
{
n = read(); m = read();
for(int i = 1; i <= n; i++) fa[i] = i;
for(int i = 1; i < n; i++)
{
int u = read(), v = read(), w = read();
add(u, v, w); add(v, u, w);
}
for(int i = 1; i <= m; i++)
{
ask[i].u = read(); ask[i].v = read();
LCA[ask[i].u].push_back((node2){ ask[i].v, i });
LCA[ask[i].v].push_back((node2){ ask[i].u, i });
}
tarjan(1, 0);
for(int i = 1; i <= m; i++)
{
r = max(r, road[ask[i].u] + road[ask[i].v] - road[ask[i].lca] * 2);
ask[i].dis = road[ask[i].u] + road[ask[i].v] - road[ask[i].lca] * 2;
}
ans = r;
while(l <= r)
{
mid = (l + r) >> 1;
if(check(mid)) { ans = mid; r = mid - 1; }
else l = mid + 1;
}
printf("%d\n", ans);
return 0;
}