题解:P10206 [JOI 2024 Final] 建设工程 2
涉及知识点:单源最短路。
解题思路
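利用拆分的思想:加边后,若新的最短路经过新边 \((u, v)\),它一定可以拆成 \(s \to u\)、新边 \(u - v\)、\(v \to t\) 三段,因此只需分别以 \(s\) 和 \(t\) 为起点各跑一次 Dijkstra。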
定义 \(dis_{0, i}\) 代表从 \(s\) 到 \(i\) 的最短路,\(dis_{1, i}\) 代表从 \(t\) 到 \(i\) 的最短路。
如果 \(s\) 到 \(t\) 的最短路已经 \(\le k\),那么无论在哪两个点之间加边都满足条件,故答案为 \(n \times (n - 1) \div 2\)。
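否则,对于每个节点 \(i\),统计满足 \(dis_{0, i} + l + dis_{1, j} \le k\) 的节点 \(j\) 的个数,累加即可。由于此时 \(dis_{0, t} > k\),对同一对点 \(\{u, v\}\),\(dis_{0, u} + l + dis_{1, v} \le k\) 与 \(dis_{0, v} + l + dis_{1, u} \le k\) 不可能同时成立(两式相加可推出 \(2\,dis_{0, t} + 2l \le 2k\),矛盾),且 \(i = j\) 时也不可能成立,所以这样统计每个无序点对恰好被计算一次。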
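具体实现时,先把所有 \(dis_{1, j}\) 排序,再对每个 \(i\) 用 upper_bound 统计满足 \(dis_{1, j} \le k - l - dis_{0, i}\) 的 \(j\) 的个数即可,总复杂度为 \(O((n + m) \log (n + m))\)。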
代码
#include <bits/stdc++.h>
// Competitive-programming shorthand macros.
// NOTE: `int` is redefined to `long long` for the whole translation unit,
// which is why the entry point below is declared `signed main`. As a
// consequence `PII` is really pair<long long, long long>.
#define int long long
// NOTE: `ll` names the 128-bit integer type, not `long long`.
#define ll __int128
#define db double
#define ldb long double
#define vo void
// `endl` is replaced by a plain newline character, so it never flushes.
#define endl '\n'
#define il inline
// `register` was removed in C++17; avoid using `re`.
#define re register
#define ve vector
#define p_q priority_queue
#define PII pair<int, int>
#define u_m unordered_map
#define bt bitset
using namespace std;
// Uncomment the next line to enable the GCC optimisation pragmas below.
//#define O2 1
#ifdef O2
#pragma GCC optimize(1)
#pragma GCC optimize(2)
#pragma GCC optimize(3, "Ofast", "inline")
#endif
// Buffered fast reader/writer over stdin/stdout.
// Input is pulled from stdin in chunks of MAXSIZE bytes; output is collected
// in pbuf and written out when the buffer fills and when the object dies.
struct IO {
    // Size in bytes of the input buffer and of the output buffer.
    // (A class constant instead of a macro, so the name does not leak.)
    static constexpr int MAXSIZE = 1 << 20;
    char buf[MAXSIZE], *p1, *p2;   // input buffer; [p1, p2) is still unread
    char pbuf[MAXSIZE], *pp;       // output buffer; [pbuf, pp) is pending
    bool eof;                      // true once fread() reported no more input
    IO() : p1(buf), p2(buf), pp(pbuf), eof(false) {}
    // Flush whatever output is still pending.
    ~IO() {
        fwrite(pbuf, 1, pp - pbuf, stdout);
    }
    // ASCII digit test. A member function replaces the old `isdigit` macro,
    // which shadowed std::isdigit for the rest of the translation unit.
    static bool digit(char ch) {
        return ch >= '0' && ch <= '9';
    }
    // Returns the next input character, refilling the buffer when needed.
    // Once stdin is exhausted it returns ' ' and sets `eof`, so callers can
    // tell that sentinel apart from a real space.
    char gc() {
        if (p1 == p2) {
            p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin);
            eof = (p1 == p2);
            if (eof) return ' ';
        }
        return *p1++;
    }
    bool blank(char ch) {
        return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t';
    }
    // Reads a number (optional '-' somewhere before the digits, optional
    // fractional part) into x. A fractional part is truncated away for
    // integral T. If input ends before any digit is seen, x is left as 0
    // instead of spinning forever.
    template <class T>
    void read(T &x) {
        double scale = 1;
        bool negative = false;
        x = 0;
        char ch = gc();
        while (!digit(ch)) {
            if (eof) return;  // no number left in the input
            if (ch == '-') negative = true;
            ch = gc();
        }
        while (digit(ch)) {
            x = x * 10 + (ch - '0');
            ch = gc();
        }
        if (ch == '.') {
            ch = gc();
            while (digit(ch)) {
                scale /= 10.0, x += scale * (ch - '0');
                ch = gc();
            }
        }
        if (negative) x = -x;
    }
    // Reads one whitespace-delimited token into s and NUL-terminates it.
    // Stores an empty string if the input is already exhausted.
    void read(char *s) {
        char ch = gc();
        while (blank(ch) && !eof) ch = gc();
        while (!blank(ch)) {
            *s++ = ch;
            ch = gc();
        }
        *s = 0;
    }
    // Reads the next non-blank character into c (' ' if input is exhausted).
    void read(char &c) {
        do {
            c = gc();
        } while (blank(c) && !eof);
    }
    // Appends one character to the output buffer, flushing it when full.
    void push(const char &c) {
        if (pp - pbuf == MAXSIZE) fwrite(pbuf, 1, MAXSIZE, stdout), pp = pbuf;
        *pp++ = c;
    }
    // Appends the decimal form of the integer x. Digits are peeled off
    // without negating x first, so the most negative value is handled, and
    // the digit buffer is large enough for 128-bit integers (39 digits).
    template <class T>
    void write(T x) {
        char digits[48];
        int len = 0;
        bool negative = x < 0;
        do {
            T q = x / 10;
            T r = x - q * 10;  // remainder in [-9, 9], same sign as x
            digits[len++] = char('0' + (r < 0 ? -r : r));
            x = q;
        } while (x != 0);
        if (negative) push('-');
        while (len) push(digits[--len]);
    }
    // Writes x followed by lastChar.
    template <class T>
    void write(T x, char lastChar) {
        write(x), push(lastChar);
    }
} io;
namespace COMB {
// Factorial table filled by Fact(); valid indices are 0..199999.
int fact[200000];
// Pascal's triangle filled by Pascal_s_triangle(); valid indices are 0..1009.
int Triangle[1010][1010];

// Sets fact[i] = i! mod `mod` for 0 <= i <= n.
// Precondition: n < 200000 (the size of `fact`).
void Fact(int n, int mod) {
    fact[0] = 1;
    for (int i = 1; i <= n; i ++ ) fact[i] = fact[i - 1] % mod * (i % mod) % mod;
}

// Sets Triangle[i][j] = C(i, j) mod `mod` for 0 <= j <= i <= n.
// Precondition: n < 1010 (the size of `Triangle`).
void Pascal_s_triangle(int n, int mod) {
    for (int i = 0; i <= n; i ++ ) Triangle[i][0] = 1;
    for (int i = 1; i <= n; i ++ )
        for (int j = 1; j <= i; j ++ )
            Triangle[i][j] = (Triangle[i - 1][j] + Triangle[i - 1][j - 1]) % mod;
}

// Returns x^y mod `mod`, always in [0, mod), by binary exponentiation.
// A negative base is first reduced into [0, mod). A negative exponent is
// treated as 0 (returns 1 % mod) instead of looping forever on `y >>= 1`.
int pw(int x, int y, int mod) {
    int res = 1 % mod;
    x %= mod;
    if (x < 0) x += mod;
    while (y > 0) {
        if (y & 1) res = res * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return res;
}

// Returns x^y with no modulus and no overflow check.
// A negative exponent is treated as 0 (returns 1).
int pw(int x, int y) {
    int res = 1;
    while (y > 0) {
        if (y & 1) res *= x;
        x *= x;
        y >>= 1;
    }
    return res;
}

// Returns gcd(x, y) mod `mod`.
int GCD(int x, int y, int mod) {
    return std::__gcd(x, y) % mod;
}

// Returns lcm(x, y) mod `mod`, in [0, mod); lcm with 0 is defined as 0.
// The division by the gcd is done on the exact value x before any modular
// reduction: dividing an already-reduced product by the gcd is not a
// valid modular operation (and could divide by zero).
int LCM(int x, int y, int mod) {
    if (x == 0 || y == 0) return 0;
    int g = std::__gcd(x, y);
    int res = (x / g) % mod * (y % mod) % mod;
    return res < 0 ? res + mod : res;
}

// Returns C(n, m) mod `mod` using the factorial table and Fermat inverses.
// Preconditions: `mod` is prime and Fact(N, mod) was called with N >= n.
int C(int n, int m, int mod) {
    if (m > n || m < 0) return 0;
    return fact[n] * pw(fact[m], mod - 2, mod) % mod * pw(fact[n - m], mod - 2, mod) % mod;
}

// Returns the precomputed Triangle[x][y] (see Pascal_s_triangle()).
int Ask_triangle(int x, int y) {
    return Triangle[x][y];
}
}
using namespace COMB;
// Build switches: define `fre` to redirect stdin/stdout to build.in /
// build.out, `IOS` to desynchronise and untie the C++ streams, and
// `multitest` to read a test count before the test cases.
//#define fre 1
#define IOS 1
//#define multitest 1
// Shared constants: size bound N, secondary bound M, a large sentinel
// value inf, and the modulus Mod.
constexpr int N = 4000010;
constexpr int M = 400010;
constexpr int inf = 100000000000000000;
constexpr int Mod = 1000000009;
namespace zla {
// Solution for JOI 2024 Final "Construction Project 2" (Luogu P10206):
// in an undirected weighted graph, count unordered vertex pairs {u, v} such
// that adding one extra edge u-v of length l makes dist(s, t) <= k.
//
// This namespace deliberately spells out `long long` and `std::` instead of
// relying on the file's shorthand macros, and keeps all per-test state in
// vectors that are rebuilt for every test case (safe with `multitest`).

// Distance stored for unreachable vertices: larger than any real distance,
// yet small enough that `k - l - kUnreachable` cannot overflow.
const long long kUnreachable = std::numeric_limits<long long>::max() / 4;

long long n, m, s, t, l, k;
// g[u] lists (neighbour, edge weight) pairs; vertices are 1-indexed.
std::vector<std::vector<std::pair<long long, long long> > > g;
// dis[0][v] = shortest distance s -> v, dis[1][v] = shortest distance t -> v.
std::vector<long long> dis[2];

// Reads one test case and rebuilds the adjacency list from scratch.
inline void Init() {
    std::cin >> n >> m >> s >> t >> l >> k;
    g.assign(n + 1, std::vector<std::pair<long long, long long> >());
    for (long long i = 0; i < m; ++i) {
        long long u, v, w;
        std::cin >> u >> v >> w;
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
    }
}

// Dijkstra from `src` over g (edge weights must be non-negative). Fills
// dist[1..n], leaving kUnreachable for vertices that cannot be reached.
inline void Dijkstra(long long src, std::vector<long long> &dist) {
    dist.assign(n + 1, kUnreachable);
    std::vector<char> settled(n + 1, 0);
    // Min-heap of (tentative distance, vertex).
    std::priority_queue<std::pair<long long, long long>,
                        std::vector<std::pair<long long, long long> >,
                        std::greater<std::pair<long long, long long> > > pq;
    dist[src] = 0;
    pq.emplace(0, src);
    while (!pq.empty()) {
        long long u = pq.top().second;
        pq.pop();
        if (settled[u]) continue;  // stale entry: u was already finalised
        settled[u] = 1;
        for (const auto &edge : g[u]) {
            long long v = edge.first, w = edge.second;
            if (dist[u] + w < dist[v]) {
                dist[v] = dist[u] + w;
                pq.emplace(dist[v], v);
            }
        }
    }
}

// Computes and prints the answer for the current test case.
inline void Solve() {
    Dijkstra(s, dis[0]);
    Dijkstra(t, dis[1]);
    // dis[0][i] + dis[1][i] >= dis[0][t] for every i (with equality at i = t),
    // so the route is already short enough iff dis[0][t] <= k; then every
    // one of the n(n-1)/2 possible new edges works.
    if (dis[0][t] <= k) {
        std::cout << n * (n - 1) / 2 << '\n';
        return;
    }
    // Otherwise {u, v} works iff dis_s[u] + l + dis_t[v] <= k for one
    // orientation. Since dist(s, t) > k, both orientations can never hold at
    // once and u == v never qualifies, so counting ordered pairs (u, v) with
    // dis_t[v] <= k - l - dis_s[u] counts each valid unordered pair once.
    std::vector<long long> fromT(dis[1].begin() + 1, dis[1].end());
    std::sort(fromT.begin(), fromT.end());
    long long ans = 0;
    for (long long u = 1; u <= n; ++u)
        ans += std::upper_bound(fromT.begin(), fromT.end(), k - l - dis[0][u]) - fromT.begin();
    std::cout << ans << '\n';
}

// Handles one test case: read it, then solve and print it.
inline void main() {
    Init();
    Solve();
}
}
// Program entry point: applies the build switches, then runs zla::main once
// per test case (one case unless `multitest` is defined).
signed main() {
#ifdef IOS
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
#endif
#ifdef fre
    freopen("build.in", "r", stdin);
    freopen("build.out", "w", stdout);
#endif
    int T = 1;
#ifdef multitest
    cin >> T;
#endif
    for (; T != 0; --T) zla::main();
    return 0;
}