CF 1307D Cow and Fields
题意:给一张联通无向图,边权全 1,其中有 k 个关键点,求选取一对关键点增加一条边之后 1~n 的最短路的最大值
赛中想的过于复杂,以为要考虑最短路中存在 0 / 1 / 2+
个关键点时的情况,用最短路树去做 dp,但是发现因为最短路树不唯一没办法解决
其实考虑 \(a_x\) 表示 1~x 的最短路,\(b_y\) 表示 y~n 的最短路, x,y
都是关键点,则有 \(res=max(min(a_x+b_y+1,a_y+b_x+1,tmp))\),其中 \(tmp\) 表示不加边时的最短路
而 \(x,y\) 的顺序其实是可以确定的,对于任意不同的 \(x,y\),如果 \(a_x-b_x \le a_y-b_y\) 则 \(a_x+b_y+1 \le a_y+b_x+1\)。所以做一个排序求前缀最值即可
代码:
#include <bits/stdc++.h>
#define ll long long
#define X first
#define Y second
#define sz size()
#define all(x) x.begin(), x.end()
using namespace std;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<long long> vl;
template <class T>
inline bool scan(T &ret){
char c;
int sgn;
if (c = getchar(), c == EOF) return 0; //EOF
while (c != '-' && (c < '0' || c > '9')) c = getchar();
sgn = (c == '-') ? -1 : 1;
ret = (c == '-') ? 0 : (c - '0');
while (c = getchar(), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
ret *= sgn;
return 1;
}
const ll mod = 1e9+7;
const int maxn = 2e5+50;
const int inf = 0x3f3f3f3f;
const double eps = 1e-8;
ll qp(ll x, ll n) {
ll res = 1; x %= mod;
while (n > 0) {
if (n & 1) res = res * x % mod;
x = x * x % mod;
n >>= 1;
}
return res;
}
int n, m, k;
vi edge[maxn];
int dist[2][maxn];
void bfs(int flag, int st) {
queue<int> q;
q.push(st);
while (!q.empty()) {
int x = q.front();
q.pop();
for (auto i : edge[x]) {
if (i != st && !dist[flag][i]) {
dist[flag][i] = dist[flag][x] + 1;
q.push(i);
}
}
}
}
int main(int argc, char* argv[]) {
scanf("%d%d%d", &n, &m, &k);
vi a;
for (int i = 1, x; i <= k; ++i) {
scanf("%d", &x);
a.push_back(x);
}
for (int i = 1, u, v; i <= m; ++i) {
scanf("%d%d", &u, &v);
edge[u].push_back(v);
edge[v].push_back(u);
}
bfs(0, 1);
bfs(1, n);
// max(min(dist[0][x] + dist[1][y] + 1, dist[0][y] + dist[1][x] + 1))
sort(a.begin(), a.end(), [](int x, int y) {
return dist[0][x] - dist[1][x] < dist[0][y] - dist[1][y];
});
int mx = a[0];
int res = 0;
for (int i = 1; i < a.size(); ++i) {
res = max(res, min(dist[0][n], dist[0][mx] + dist[1][a[i]] + 1));
// cerr << mx << " " << a[i] << endl;
if (dist[0][mx] < dist[0][a[i]]) mx = a[i];
}
printf("%d\n", res);
return 0;
}