CF 1307D Cow and Fields

Cow and Fields

题意:给一张联通无向图,边权全 1,其中有 k 个关键点,求选取一对关键点增加一条边之后 1~n 的最短路的最大值

赛中想的过于复杂,以为要考虑最短路中存在 0 / 1 / 2+ 个关键点时的情况,用最短路树去做 dp,但是发现因为最短路树不唯一没办法解决

其实考虑 \(a_x\) 表示 1~x 的最短路,\(b_y\) 表示 y~n 的最短路, x,y 都是关键点,则有 \(res=max(min(a_x+b_y+1,a_y+b_x+1,tmp))\),其中 \(tmp\) 表示不加边时的最短路

\(x,y\) 的顺序其实是可以确定的,对于任意不同的 \(x,y\),如果 \(a_x-b_x \le a_y-b_y\)\(a_x+b_y+1 \le a_y+b_x+1\)。所以做一个排序求前缀最值即可

代码:

#include <bits/stdc++.h>
#define ll long long
#define X first
#define Y second
#define sz size()
#define all(x) x.begin(), x.end()
using namespace std;

typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<long long> vl;

template <class T>
inline bool scan(T &ret){
    char c;
    int sgn;
    if (c = getchar(), c == EOF) return 0; //EOF
    while (c != '-' && (c < '0' || c > '9')) c = getchar();
    sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    ret *= sgn;
    return 1;
}

const ll mod = 1e9+7;
const int maxn = 2e5+50;
const int inf = 0x3f3f3f3f;
const double eps = 1e-8;

ll qp(ll x, ll n) {
    ll res = 1; x %= mod;
    while (n > 0) {
        if (n & 1) res = res * x % mod;
        x = x * x % mod;
        n >>= 1;
    }
    return res;
}

int n, m, k;
vi edge[maxn];
int dist[2][maxn];

void bfs(int flag, int st) {
    queue<int> q;
    q.push(st);
    while (!q.empty()) {
        int x = q.front();
        q.pop();
        for (auto i : edge[x]) {
            if (i != st && !dist[flag][i]) {
                dist[flag][i] = dist[flag][x] + 1;
                q.push(i);
            }
        }
    }
}

int main(int argc, char* argv[]) {
    scanf("%d%d%d", &n, &m, &k);
    vi a;
    for (int i = 1, x; i <= k; ++i) {
        scanf("%d", &x);
        a.push_back(x);
    }
    for (int i = 1, u, v; i <= m; ++i) { 
        scanf("%d%d", &u, &v);
        edge[u].push_back(v);
        edge[v].push_back(u);
    }
    bfs(0, 1);
    bfs(1, n);
    // max(min(dist[0][x] + dist[1][y] + 1, dist[0][y] + dist[1][x] + 1))
    sort(a.begin(), a.end(), [](int x, int y) {
        return dist[0][x] - dist[1][x] < dist[0][y] - dist[1][y];
    });
    int mx = a[0];
    int res = 0;
    for (int i = 1; i < a.size(); ++i) {
        res = max(res, min(dist[0][n], dist[0][mx] + dist[1][a[i]] + 1));
        // cerr << mx << " " << a[i] << endl;
        if (dist[0][mx] < dist[0][a[i]]) mx = a[i];
    }
    printf("%d\n", res);
    return 0;
}
posted @ 2020-02-19 17:10  badcw  阅读(181)  评论(0编辑  收藏  举报