「解题报告」UOJ605 [UER #9] 知识网络
好像并不是很难的题?虽然从上午想到现在才开始写,还因为不知道 __builtin_popcount(x)
传入的是 int
调了一个多小时
题目就是要求一个全源最短路。直接求显然不太现实,考虑分析标签的性质。发现,同一标签内的所有点到某个点 \(u\) 的最短路的差值一定不超过 \(1\),因为同一标签下的点可以互相到达。那么,如果我们可以求出同一标签下的所有点到每个点的最短路,并且求出有多少点到这个点是最短路即可。
前者是好处理的,直接跑最短路就可以做到 \(O(km \log m)\)。由于是 01 最短路,跑 BFS 即可做到 \(O(km)\),Dijkstra 也能过。
后者我们可以建出最短路 DAG,那么我们就是要统计有多少个点能够在 DAG 上到达每个点。这东西看起来就不像是什么正经算法能做的东西,所以我们直接考虑 bitset
。我们只需要建出 DAG,然后拓扑排序并且跑 bitset
即可。但是这样复杂度就是 \(O(\frac{kn^2}{w})\) 了..吗?
实际上,每次的 bitset
大小只需要开标签的个数即可,而每个标签的数量总和显然是 \(O(n)\),那么我们可以手动实现一个不定长 bitset
,然后再跑,复杂度就是 \(O(\frac{n^2}{w})\) 了,可以通过。
然后你就被 Extra Test 创飞了。
发现空间限制 256MB,空间复杂度 \(O(\frac{n^2}{w})\) 是无法接受的。这时候,我们可以应用 毒瘤 lxl 在毒瘤分块题中卡空间的常见做法 来解决,即每 \(w\) 个数分一块,然后每次只计算每一块的答案,最后加起来即可。空间复杂度降至 \(O(n)\)。
然后记得使用 __builtin_popcountll
。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 510005;
int n, m, k;
int p[MAXN];
vector<pair<int, int>> e[MAXN];
const int B = 6, MASK = (1 << B) - 1;
vector<int> pt[MAXN];
long long ans[MAXN];
unsigned long long f[MAXN];
int dis[MAXN];
bool vis[MAXN];
struct DAG {
vector<int> e[MAXN];
int deg[MAXN], de[MAXN];
void add(int u, int v) {
e[u].push_back(v);
de[v]++;
}
void clear() {
for (int i = 1; i <= n + k; i++) {
e[i].clear();
de[i] = 0;
}
}
void topo() {
for (int i = 1; i <= n + k; i++) deg[i] = de[i];
queue<int> q;
for (int i = 1; i <= n + k; i++) if (!deg[i]) q.push(i);
while (!q.empty()) {
int u = q.front(); q.pop();
for (int v : e[u]) {
f[v] |= f[u];
deg[v]--;
if (!deg[v]) q.push(v);
}
}
}
} g;
int main() {
scanf("%d%d%d", &n, &m, &k);
for (int i = 1; i <= n; i++) {
scanf("%d", &p[i]);
pt[p[i]].push_back(i);
e[n + p[i]].push_back({i, 1});
e[i].push_back({n + p[i], 0});
}
for (int i = 1; i <= m; i++) {
int u, v; scanf("%d%d", &u, &v);
e[u].push_back({v, 1});
e[v].push_back({u, 1});
}
for (int i = 1; i <= k; i++) {
int m = pt[i].size();
for (int j = 1; j <= n + k; j++) vis[j] = 0, dis[j] = INT_MAX / 2;
priority_queue<pair<int, int>, vector<pair<int, int>>, greater<pair<int, int>>> q;
for (int j : pt[i]) q.push({0, j}), dis[j] = 0;
while (!q.empty()) {
int u = q.top().second; q.pop();
if (vis[u]) continue;
vis[u] = 1;
for (auto p : e[u]) {
int v = p.first, w = p.second;
if (vis[v]) continue;
if (dis[v] > dis[u] + w) {
dis[v] = dis[u] + w;
q.push({dis[v], v});
}
}
}
g.clear();
for (int u = 1; u <= n + k; u++) {
for (auto p : e[u]) {
int v = p.first, w = p.second;
if (dis[v] == dis[u] + w) {
g.add(u, v);
}
}
}
for (int l = 0, r; l < m; l = r + 1) {
r = min(l + 63, m - 1);
for (int j = 1; j <= n + k; j++) f[j] = 0;
for (int j = l; j <= r; j++) f[pt[i][j]] |= 1ull << (j - l);
g.topo();
for (int j = 1; j <= n; j++) {
if (dis[j] < INT_MAX / 2) {
ans[dis[j]] += __builtin_popcountll(f[j]);
ans[dis[j] + 1] += (r - l + 1) - __builtin_popcountll(f[j]);
} else {
ans[2 * k] += (r - l + 1);
}
}
}
}
printf("0 ");
for (int i = 1; i <= 2 * k; i++) printf("%lld ", ans[i] / 2);
return 0;
}