题解:P10206 [JOI 2024 Final] 建设工程 2

涉及知识点:单源最短路。

解题思路

利用拆分的思想。

定义 \(dis_{0, i}\) 代表从 \(s\) 到 \(i\) 的最短路长度,\(dis_{1, i}\) 代表从 \(t\) 到 \(i\) 的最短路长度。

如果 \(s\) 到 \(t\) 的最短路已经 \(\le k\) 了,那么无论怎么添加新边都可以满足条件,故答案为 \(n \times (n - 1) \div 2\)。

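否则(此时 \(s\) 到 \(t\) 的最短路 \(> k\)),对于每个节点 \(i\),统计满足 \(dis_{0, i} + l + dis_{1, j} \le k\) 的节点 \(j\) 的个数,累加即可。由于此时对任意节点 \(x\) 都有 \(dis_{0, x} + dis_{1, x} > k\),同一无序点对的两个方向不可能同时满足条件,\(i = j\) 也不可能满足条件,所以不会重复计数。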

答案的计算可以直接用 upper_bound 实现:将所有 \(dis_{1, j}\) 排序后,对每个 \(i\) 二分出不超过 \(k - l - dis_{0, i}\) 的元素个数。

代码

#include <bits/stdc++.h>
#define int long long
#define ll __int128
#define db double
#define ldb long double
#define vo void
#define endl '\n'
#define il inline
#define re register
#define ve vector
#define p_q priority_queue
#define PII pair<int, int>
#define u_m unordered_map
#define bt bitset

using namespace std;

//#define O2 1
#ifdef O2
	#pragma GCC optimize(1)
	#pragma GCC optimize(2)
	#pragma GCC optimize(3, "Ofast", "inline")
#endif

/**
 * Buffered reader/writer on top of stdin/stdout.
 *
 * Input is pulled from stdin in MAXSIZE-byte chunks; output is collected in
 * `pbuf` and written to stdout when the buffer fills up and once more from the
 * destructor.
 *
 * End of input: `gc()` keeps returning ' ' once stdin is exhausted and sets
 * `eof`. Every `read` overload stops skipping as soon as `eof` is set, so
 * reading past the end terminates instead of spinning forever.
 */
struct IO {
#define MAXSIZE (1 << 20)
#define isdigit(x) ((x) >= '0' && (x) <= '9')
	char buf[MAXSIZE], *p1, *p2;
	char pbuf[MAXSIZE], *pp;
	bool eof;  // set once fread() delivered no more data
	IO() : p1(buf), p2(buf), pp(pbuf), eof(false) {}

	/// Flushes whatever is still pending in the output buffer.
	~IO() {
		fwrite(pbuf, 1, pp - pbuf, stdout);
	}

	/// Returns the next input character, or ' ' (and sets `eof`) at end of input.
	char gc() {
		if (p1 == p2) {
			p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin);
			if (p1 == p2) {
				eof = true;
				return ' ';
			}
		}
		return *p1++;
	}

	/// True for space, newline, carriage return and tab.
	bool blank(char ch) {
		return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t';
	}

	/**
	 * Reads a decimal number into `x`.
	 *
	 * Skips everything up to the first digit; a '-' seen while skipping makes
	 * the result negative. An optional ".digits" suffix is added as a fraction
	 * (only meaningful when T is a floating-point type).
	 * If the input ends before any digit is found, `x` is left at 0.
	 */
	template <class T>
	void read(T &x) {
		double tmp = 1;
		bool sign = 0;
		x = 0;
		char ch = gc();
		while (!isdigit(ch)) {
			if (eof) return;  // nothing left to parse: avoid an endless skip loop
			if (ch == '-') sign = 1;
			ch = gc();
		}
		while (isdigit(ch)) {
			x = x * 10 + (ch - '0');
			ch = gc();
		}
		if (ch == '.') {
			ch = gc();
			while (isdigit(ch)) {
				tmp /= 10.0, x += tmp * (ch - '0');
				ch = gc();
			}
		}
		if (sign) x = -x;
	}

	/// Reads one blank-delimited token into `s` (NUL-terminated; empty at end of input).
	void read(char *s) {
		char ch = gc();
		for (; blank(ch) && !eof; ch = gc());
		// At end of input `ch` is the ' ' placeholder, so this loop is skipped.
		for (; !blank(ch); ch = gc()) * s++ = ch;
		*s = 0;
	}

	/// Reads the next non-blank character into `c` (' ' at end of input).
	void read(char &c) {
		for (c = gc(); blank(c) && !eof; c = gc());
	}

	/// Appends one character to the output buffer, flushing it first when full.
	void push(const char &c) {
		if (pp - pbuf == MAXSIZE) fwrite(pbuf, 1, MAXSIZE, stdout), pp = pbuf;
		*pp++ = c;
	}

	/**
	 * Writes the integer `x` in decimal.
	 *
	 * Digits are taken from the still-signed value, so the most negative value
	 * of T is printed correctly (negating it first would overflow). The local
	 * buffer holds 40 characters, enough for the 39 digits of a 128-bit integer.
	 */
	template <class T>
	void write(T x) {
		char digits[40];
		int top = 0;
		const bool negative = x < 0;
		do {
			const int d = static_cast<int>(x % 10);  // in [-9, 9]; sign follows x
			digits[top++] = static_cast<char>('0' + (d < 0 ? -d : d));
			x /= 10;
		} while (x);
		if (negative) push('-');
		while (top) push(digits[--top]);
	}

	/// Writes `x` followed by `lastChar`.
	template <class T>
	void write(T x, char lastChar) {
		write(x), push(lastChar);
	}
} io;

/**
 * Small combinatorics / number-theory toolbox.
 *
 * Every `mod` argument must be positive. `Fact` must be called before `C`,
 * and `Pascal_s_triangle` before `Ask_triangle`; indices must stay inside
 * the fixed-size tables below (no bounds checks are performed).
 */
namespace COMB {
	int fact[200000];          // fact[i] = i! mod `mod`, filled by Fact()
	int Triangle[1010][1010];  // Triangle[i][j] = C(i, j) mod `mod`, filled by Pascal_s_triangle()

	/// Fills fact[0..n] with factorials modulo `mod`. Requires n < 200000.
	void Fact(int n, int mod) {
		fact[0] = 1;
		for (int i = 1; i <= n; i ++ ) fact[i] = ((fact[i - 1]) % mod * (i % mod)) % mod;
	}

	/// Fills rows 0..n of Pascal's triangle modulo `mod`. Requires n < 1010.
	void Pascal_s_triangle(int n, int mod) {
		for (int i = 0; i <= n; i ++ ) Triangle[i][0] = 1;
		for (int i = 1; i <= n; i ++ )
			for (int j = 1; j <= i; j ++ )
				Triangle[i][j] = (Triangle[i - 1][j] + Triangle[i - 1][j - 1]) % mod;
	}

	/**
	 * Returns x^y modulo `mod` by binary exponentiation.
	 * A non-positive exponent yields 1 % mod (the loop no longer spins on
	 * negative y, whose arithmetic right shift never reaches 0).
	 */
	int pw(int x, int y, int mod) {
		int res = 1 % mod;  // 0 when mod == 1
		x %= mod;
		while (y > 0) {
			if (y & 1) res = res * x % mod;
			x = x * x % mod;
			y >>= 1;
		}
		return res;
	}

	/// Returns x^y without any overflow protection; non-positive y yields 1.
	int pw(int x, int y) {
		int res = 1;
		while (y > 0) {
			if (y & 1) res *= x;
			x *= x;
			y >>= 1;
		}
		return res;
	}

	/// Plain Euclidean algorithm; EuclidGcd(0, 0) == 0.
	int EuclidGcd(int x, int y) {
		while (y != 0) {
			int r = x % y;
			x = y;
			y = r;
		}
		return x;
	}

	/// Returns gcd(x, y) modulo `mod`.
	int GCD(int x, int y, int mod) {
		return EuclidGcd(x, y) % mod;
	}

	/**
	 * Returns lcm(x, y) modulo `mod`.
	 *
	 * The division by the gcd is done on the exact value of x before any
	 * reduction: dividing an already reduced product gives wrong results and
	 * can divide by zero. lcm(0, 0) is reported as 0.
	 */
	int LCM(int x, int y, int mod) {
		const int g = EuclidGcd(x, y);
		if (g == 0) return 0;
		return (x / g) % mod * (y % mod) % mod;
	}

	/**
	 * Returns the binomial coefficient C(n, m) modulo `mod` (0 when m is out
	 * of range). Uses Fermat inverses, so `mod` must be a prime larger than n,
	 * and Fact() must have been called with the same `mod` and a bound >= n.
	 */
	int C(int n, int m, int mod) {
		if (m > n || m < 0) return 0;
		return fact[n] * pw(fact[m], mod - 2, mod) % mod * pw(fact[n - m], mod - 2, mod) % mod;
	}

	/// Returns the precomputed Pascal triangle entry Triangle[x][y].
	int Ask_triangle(int x, int y) {
		return Triangle[x][y];
	}
}
using namespace COMB;

// Build-time switches, tested with #ifdef further down in this file.
// `fre`: when defined, stdin/stdout are redirected to build.in / build.out.
//#define fre 1
// `IOS`: when defined, iostreams are decoupled from C stdio and cin is untied.
#define IOS 1
// `multitest`: when defined, the number of test cases is read from the input;
// otherwise exactly one case is processed.
//#define multitest 1

// Note: `int` is a macro for `long long` in this file, so all of the values
// below are 64-bit.
// Capacity bounds (4e6+10 and 4e5+10).
const int N = 4e6 + 10;
const int M = 4e5 + 10;
// Large "infinity" value; 1e17 only fits because of the 64-bit `int` macro,
// and the sum of two such values still fits in 64 bits.
const int inf = 1e17;
// Modulus constant 1e9+9.
const int Mod = 1e9 + 9;

/**
 * Solver for one test case.
 *
 * Input (stdin): `n m`, then `s t l k`, then m undirected weighted edges
 * `u v w`; vertices are numbered 1..n. Output (stdout): the number of
 * unordered vertex pairs {u, v} such that adding an edge u-v of length l
 * makes the distance from s to t at most k (no trailing newline).
 *
 * All per-case state is rebuilt on every call of main(), so the solver can
 * be run repeatedly. Storage is sized by n instead of a fixed maximum.
 */
namespace zla {
	typedef long long i64;
	typedef std::pair<i64, i64> Edge;  // (neighbour, edge length)

	// Distance of a vertex that has not been reached (1e17). The sum of two
	// such values still fits in 64 bits.
	const i64 UNREACHED = 100000000000000000LL;

	i64 n, m, s, t, l, k;
	std::vector<std::vector<Edge> > g;  // adjacency lists, index 1..n
	std::vector<i64> a;                 // sorted distances to t (0-based)
	i64 ans;
	std::vector<i64> dis[2];            // dis[0]: distances from s, dis[1]: from t
	std::vector<char> vis[2];           // settled flags for the two Dijkstra runs
	struct NODE {
		i64 u, cnt;  // vertex and its tentative distance
		bool operator < (const NODE &other) const {
			return cnt < other.cnt;
		}
		bool operator > (const NODE &other) const {
			return cnt > other.cnt;
		}
	};
	// Min-heap on `cnt`; always empty between Dijkstra runs.
	std::priority_queue<NODE, std::vector<NODE>, std::greater<NODE> > Q;

	/// Reads one test case and rebuilds the adjacency lists from scratch.
	inline void Init() {
		std::cin >> n >> m >> s >> t >> l >> k;
		g.assign(n + 1, std::vector<Edge>());
		for (i64 e = 0; e < m; ++e) {
			i64 u, v, w;
			std::cin >> u >> v >> w;
			g[u].push_back(std::make_pair(v, w));
			g[v].push_back(std::make_pair(u, w));
		}
	}

	/// Dijkstra from `src`; fills dis[side] and vis[side] (side is 0 or 1).
	inline void Dijkstra(i64 src, i64 side) {
		std::vector<i64> &d = dis[side];
		std::vector<char> &done = vis[side];
		d.assign(n + 1, UNREACHED);
		done.assign(n + 1, 0);
		d[src] = 0;
		Q.push(NODE{src, 0});
		while (!Q.empty()) {
			const i64 u = Q.top().u;
			Q.pop();
			if (done[u]) continue;  // stale heap entry
			done[u] = 1;
			for (const Edge &e : g[u]) {
				const i64 v = e.first;
				const i64 w = e.second;
				if (d[v] > d[u] + w) {
					d[v] = d[u] + w;
					Q.push(NODE{v, d[v]});
				}
			}
		}
	}

	/// Computes and prints the answer for the test case read by Init().
	inline void Solve() {
		Dijkstra(s, 0);
		Dijkstra(t, 1);

		// s already reaches t within k: every one of the n*(n-1)/2 pairs works.
		if (dis[0][t] <= k) {
			std::cout << n * (n - 1) / 2;
			return ;
		}

		// Otherwise dis[0][x] + dis[1][x] > k for every x, so for a pair {i, j}
		// at most one orientation "s -> i, new edge, j -> t" can fit into k and
		// i == j never does: counting ordered pairs counts each pair once.
		a.assign(dis[1].begin() + 1, dis[1].end());
		std::sort(a.begin(), a.end());
		ans = 0;
		for (i64 i = 1; i <= n; ++i)
			ans += std::upper_bound(a.begin(), a.end(), k - l - dis[0][i]) - a.begin();
		std::cout << ans;
	}

	/// Entry point for one test case.
	inline void main() {
		Init();
		Solve();
	}
}

// Program entry point: applies the optional I/O set-up selected by the
// IOS / fre macros, determines the number of test cases and runs the
// solver once per case.
signed main() {
#ifdef IOS
	// Faster iostreams: no syncing with C stdio, no automatic flushes.
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
#endif
#ifdef fre
	freopen("build.in", "r", stdin);
	freopen("build.out", "w", stdout);
#endif
	int T = 1;
#ifdef multitest
	cin >> T;
#endif
	while (T != 0) {
		--T;
		zla::main();
	}
	return 0;
}

posted @ 2024-10-25 14:41  zla_2012  阅读(3)  评论(0编辑  收藏  举报