「WOJ 4701」Walk 虚点建图+01bfs
前言
模拟赛中,yzh遇到了这道题,但由于yzh没有学过01bfs。。。所以就一直在优化这道题,导致浪费了很长时间 (但nfls的数据太水,dij和spfa都能过)
Description
给你一个 \(n\) 个点,\(m\) 条边的图,经过每条边需要消耗 \(1\) 的时间。
每个点还会有一个权值 \(val_i\) 。使得
任意两个房间 \(u,v\) 可以消耗 \(1\) 的时间传送,当且仅当 \(val_u \&val_v=val_v\)
其中 \(\&\) 为与运算。
现在,你站在 \(1\) 点,你需要求出从 \(1\) 号点到达每一个点,所需要的最短时间。
输入格式
第一行 \(n,m\)。
第二行 \(n\) 个数,代表 \(val_1,val_2,\cdots,val_n\)。
接下来 \(m\) 行,每行 \(u,v\) 表示一条从 \(u\) 到 \(v\) 的有向边。
输出格式
输出 \(n\) 行。第 \(i\) 行输出从 \(1\) 到 \(i\) 所消耗的最小时间。
数据范围
Solution
暴力
有一种暴力的做法,我们可以暴力枚举每两个节点的 \(val\)。并对其进行连边,然后跑一遍bfs。这样做边的数量会非常多,显然不能拿全分。
优化建边
考虑优化建边,不难想到,建若干个虚点 \(x\) ,我们只需要将 \(val=x\) 的节点向 \(x\) 连一条长度为 \(0\) 的边即可,然后对于两个不同的虚点,若满足 \(x\&x'=x'\) 那么虚点 \(x\) 向 \(x'\) 连一条长度为 \(0\) 的有相边。这样做显然还是有很多边因为 \(1\le val_i\le 2^{30}\)。
考虑继续优化建边,类似于高维前缀和的做法,对于两个虚点 \(x\) 和 \(x'\),满足 \(x\oplus x'=2^k(k\epsilon \mathbb N)\) ,那么 \(x\) 向 \(x'\) 连一条长度为 \(0\) 的有向边,可以把这种优化建边理解为一种树形结构。比如 \((111)_2\) 向 \((110)_2\) 连边,\((110)_2\) 向 \((100)_2\) 连边,那么就相当于 \((111)_2\) 到 \((100)_2\) 连了一条长度为 \(0\) 的边。
求最短路——01bfs
建完边考虑如何求最短路,由于边长既有 \(1\) 也有 \(0\) 所以普通的bfs是肯定不行的,我们可以用dij进行求最短路,但可能会不过了 (但也能卡过),对于虚点建图显然是一个较为稀疏的图,spfa实测跑的比dij快。正确解法是01bfs,具体的,队列队头的点,将它延申出去的边中长度为 \(0\) 的边的结束点先放入队头,再将长度为 \(1\) 的放入队尾,显然的,这样就能够使得求出来的长度最小了,这个01bfs的过程可以用deque实现
CODE
01bfs 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int M = 998244353, N = (1 << 20) + (1000000) + 1000, T = (1 << 20), inf = 0x3f3f3f3f;
int val[N];
struct node {
int id, v;
};
vector<node> a[N];
int lowbit(int x) { return x & (-x); }
int dis[N];
deque<node> q;
bool vis[N];
int main() {
int n, m;
scanf("%d%d", &n, &m);
int mx = 0;
for (int i = 1; i <= n; i++) {
scanf("%d", &val[i]);
mx = max(mx, val[i]);
a[val[i]].emplace_back(node{ i + T, 0 });
a[i + T].emplace_back(node{ val[i], 1 });
}
for (int i = 1; i <= m; i++) {
int u, v;
scanf("%d%d", &u, &v);
a[u + (T)].emplace_back(node{ v + T, 1 });
// a[v + T].emplace_back(node{u + T , 1});
}
for (int i = 1; i < (T); i++) {
int t = i;
for (; t >= 1; t -= lowbit(t)) {
if (i - lowbit(t) == 0)
break;
a[i].emplace_back(node{ i - lowbit(t), 0 });
}
}
memset(dis, 0x3f, sizeof(dis));
q.push_front(node{ 1 + T, 0 });
dis[1 + T] = 0;
while (!q.empty()) {
node u = q.front();
q.pop_front();
if (vis[u.id])
continue;
vis[u.id] = true;
for (auto v : a[u.id]) {
if (vis[v.id])
continue;
if (v.v == 0)
q.push_front(node{ v.id, u.v }), dis[v.id] = min(dis[v.id], u.v);
if (v.v == 1)
q.push_back(node{ v.id, u.v + 1 }), dis[v.id] = min(dis[v.id], u.v + 1);
}
}
for (int i = 1 + T; i <= n + T; i++) {
if (dis[i] == inf)
puts("-1");
else
printf("%d\n", dis[i]);
}
return 0;
}
dij代码
#include <bits/stdc++.h>
using namespace std;
#define pii pair<int, int>
const int inf = 0x3f3f3f3f;
const int N = 2e5 + 7;
const int MAXN = (1 << 20);
int n, m;
int valt[N];
vector<int> edge[N];
vector<int> node[MAXN + 7];
int dist[N + MAXN];
bool vis[N + MAXN];
priority_queue<pii, vector<pii>, greater<pii> > qu;
// int pre[N+MAXN];
void dij() {
memset(dist, inf, sizeof(dist));
dist[1] = 0;
qu.push(make_pair(0, 1));
while (!qu.empty()) {
int u = qu.top().second;
// cout<<u<<endl;
int minn = qu.top().first;
qu.pop();
if (vis[u])
continue;
vis[u] = true;
if (u <= n) {
for (int i = 0; i < edge[u].size(); ++i) { //常规m条边
int v = edge[u][i];
if (dist[v] > minn + 1 && !vis[v]) {
dist[v] = minn + 1;
qu.push(make_pair(dist[v], v));
}
}
if (!vis[valt[u] + n]) {
qu.push(make_pair(minn, valt[u] + n));
// if(valt[u]==6655)
// cout<<u<<endl;
}
continue;
}
u -= n;
for (int i = 0; i < node[u].size(); ++i) { //点权为u的点
int v = node[u][i];
if (dist[v] > minn + 1 && !vis[v]) {
dist[v] = minn + 1;
qu.push(make_pair(dist[v], v));
// if(v==3) cout<<u<<"->"<<dist[v]<<endl;
}
}
for (int i = 0; i <= 20; ++i) { //虚点之间连边
if (((u >> i) & 1)) {
int v = (u ^ (1 << i));
if (v == 0)
continue;
if (dist[v + n] > minn && !vis[v + n]) {
dist[v + n] = minn;
qu.push(make_pair(minn, v + n));
// if(v==6655){
// cout<<u<<";;;\n";
// }
}
}
}
}
}
signed main() {
clock_t st, ed;
st = clock();
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) {
scanf("%d", &valt[i]);
node[valt[i]].push_back(i);
}
for (int i = 1; i <= m; ++i) {
int u, v;
scanf("%d%d", &u, &v);
edge[u].push_back(v);
}
dij();
for (int i = 1; i <= n; ++i) {
if (dist[i] == inf) {
puts("-1");
continue;
}
printf("%d\n", dist[i]);
}
ed = clock();
// cout<<endl<<"run:"<<ed-st<<"ms"<<endl;
return 0;
}