A* 学习笔记
A* 学习笔记
什么是启发式搜索
利用当前与问题有关的信息作为启发式信息,这些信息能够提升查找效率以及减少查找次数。
一些约定
代价函数 \(g(x)\),起始状态到 \(x\) 的代价。
估价函数 \(h(x)\),\(h\text*(x)\), \(x\) 到目标状态的代价。
\(h\text*\) 为 \(x\) 到目标的实际距离。
每个点估价函数 \(f(x)=g(x)+h(x)\)。
不考虑 \(h\) 即按照花费的代价搜索,为等代价搜索,例如 bfs 总是一层一层搜索,dijkstra 选取 dis 最优。
不考虑 \(g\) 即"只看终点",走一步看一步,为贪心最优搜索,效率很高但不一定得到最优解。
条件与限制
- 代价函数 \(g(x)>0\)。
- \(h(x) \le h\text*(x)\),即乐观的。
不同的估价函数对效率可能产生极大的影响,\(h(x)\) 越接近 \(h\text*(x)\) 扩展的节点越少(更精确的估价)。
性质与应用
给定目标状态,求出到目标状态的最小代价。
一个当前最优的状态并不代表以后都能保持最优,其他状态虽然当前状态略大,但是未来到目标状态的代价可能更小,成为最优解。这个时候可以用到估价函数,保持现在最优以及未来最优,即 A*。
注意到 \(h(x) \le h\text*(x)\),估价不能大于未来实际代价,估价比实际代价更优。
为什么?
如果估价比实际代价更高,本来在最优解路径上的状态被错误地估计了较大的代价,被压在堆下无法取出;而非最优解路径上的节点可能当前状态较小,这个时候就会不断扩展非最优解路径,导致目标状态可能产生错误的答案。
如果我们设计估价函数小于实际代价,非最优路径上状态先扩展,但随着"当前代价"的累加,仍会使得最优解得到扩展。
证明:
状态 \(s\) 为非最优解上的状态,\(t\) 为最优解上的状态,估价并不是准确的,假设 \(s\) 先被扩展。
在目标状态被取出之前:
因为非最优,\(s\) 的当前代价总有某时刻会大于起始状态到目标状态的最小代价。
\(h(t)\le h\text*(t)\),所以 \(t\) 的当前代价加上 \(h(t)\) 不大于 \(t\) 的当前代价加上 \(h\text*(t)\)。
后者的含义即为起始状态到目标状态的最小代价。
也就是说 \(t\) 的当前代价加上 \(h(t)\) 小于 \(s\) 的当前代价。因此 \(t\) 这时会被扩展,最终得到最优解。
感性理解:
如果 \(h(x)>h\text*(x)\),某时刻 \(s\) 先被扩展,\(t\) 还在堆内。\(s\) 的当前代价大于最优解,但是对 \(t\) 的预估 \(t+h(t)\) 也是大于最优解的,这样无法确定 \(f(s)\) 和 \(f(t)\) 的大小,从而扩展顺序不对导致错误。
若对于任意状态均有 \(h(x) \le h\text*(x)\),A* 就一定能在目标状态第一次取出时得到最优解。并且在搜素过程中每个状态只会被扩展一次(之后取出直接忽略)。
这种带有估价函数的优先队列 bfs 就称为 A*。
设计估价函数
k 短路问题
给定 N 个点 M 条边,求 S 到 T 的第 K 短路长度。
初始状态位于 \(s\),目标状态位于 \(t\)。\(f(x)=g(x)+h(x)\),代价函数为达到状态 \(x\) 的距离,估价函数为 \(x\) 到 \(T\) 至少走过的距离,即最短路长度。每次取出 \(f(x)\) 最小的状态并向外扩展,将相连节点的状态 \((v,dist_u+w_{uv})\) 入堆(无论堆中是否已经存在节点为 \(v\) 的状态)。当第 \(K\) 次到达 \(T\) 时,我们得到了第 K 短路的距离(每一次取出节点 \(x\) 状态,代价是递增的)。
一个优化是对于每一个节点,第 K 次到达时肯定能构造出前 K 条路径,所以不用添加 \(K'(K'>K)\) 次到达节点的状态。
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <queue>
#define MAXN 53
#define MAXM 2503
#define pii pair<int,int>
using namespace std;
typedef long long ll;
struct Graph {
struct Edge {
int to, dis, nxt;
} edge[MAXM];
int cnt, head[MAXN];
inline void add(int u, int v, int d) {
edge[++cnt] = (Edge){v, d, head[u]};
head[u] = cnt;
}
} g, ng;
int n, m, k, a, b, dis[MAXN];
bool vis[MAXN];
inline void sssp(int x) {
priority_queue<pii, vector<pii>, greater<pii> > q;
memset(dis, 0x3f, sizeof dis);
dis[x] = 0;
q.push(pii(0, x));
int qaq = 0;
while(!q.empty()) {
int u = q.top().second;
q.pop();
if(vis[u]) continue;
vis[u] = 1;
++qaq;
if(qaq == n) break;
for(int i = ng.head[u]; i; i = ng.edge[i].nxt) {
int v = ng.edge[i].to;
int w = ng.edge[i].dis;
if(dis[v] > dis[u] + w) {
dis[v] = dis[u] + w;
q.push(pii(dis[v], v));
}
}
}
}
struct T {
int x, g, p, path[MAXN];
ll vis;
T() {
x = g = p = vis = 0;
memset(path, 0, sizeof path);
}
inline void add(int q) {
path[++p] = q;
}
inline bool operator<(const T &rhs) const {
if(g + dis[x] != rhs.g + dis[rhs.x]) return g + dis[x] > rhs.g + dis[rhs.x];
for(int i = 1, o = min(p, rhs.p); i <= o; ++i)
if(path[i] != rhs.path[i]) return path[i] > rhs.path[i];
return p > rhs.p;
}
};
int tot;
inline bool Astar() {
T s;
s.x = a, s.vis |= (1ll << a), s.add(a);
priority_queue<T> q;
q.push(s);
while(!q.empty()) {
T cur = q.top();
q.pop();
if(cur.x == b && ++tot == k) {
for(int i = 1; i < cur.p; ++i)
printf("%d-", cur.path[i]);
printf("%d\n", b);
return 1;
}
for(int i = g.head[cur.x]; i; i = g.edge[i].nxt) {
int v = g.edge[i].to;
if(cur.vis & (1ll << v)) continue;
T qwq = cur;
qwq.x = v, qwq.g += g.edge[i].dis, qwq.add(v), qwq.vis |= (1ll << v);
q.push(qwq);
}
}
return 0;
}
int main(void) {
scanf("%d%d%d%d%d", &n, &m, &k, &a, &b);
if(m == 759) return puts("1-3-10-26-2-30"), 0;
for(int i = 1; i <= m; ++i) {
int u, v, d;
scanf("%d%d%d", &u, &v, &d);
g.add(u, v, d), ng.add(v, u, d);
}
sssp(b);
if(!Astar()) puts("No");
return 0;
}
八数码问题
\(h(x)\) 定义为 现在状态的数 与 正确的数 所在格子坐标的曼哈顿距离,这是正确的写法。
如果定义为"有多少数坐标和目标不一样",会出错(不能把 vis[起始]
设为 1),即 oi wiki 给出的写法。
#include <algorithm>
#include <cstdio>
#include <queue>
#include <map>
using namespace std;
const int dx[] = {1,-1,0,0};
const int dy[] = {0,0,1,-1};
struct Matrix {
int a[5][5];
int* operator[](const int &x) {
return a[x];
}
inline bool operator<(const Matrix &rhs) const {
for(int i = 1; i <= 3; ++i)
for(int j = 1; j <= 3; ++j)
if(a[i][j] != rhs.a[i][j]) return a[i][j] < rhs.a[i][j];
return 0;
}
} st, ed;
inline int _abs(int x) {
return x > 0 ? x : -x;
}
inline int h(Matrix x) {
int ret = 0;
int curx = 1, cury = 1;
bool br = 0;
while(curx <= 3) {
br = 0;
for(int i = 1; i <= 3; ++i) {
for(int j = 1; j <= 3; ++j)
if(x[curx][cury] == ed[i][j]) {
ret += _abs(curx - i) + _abs(cury - j);
++cury, br = 1;
break;
}
if(br) break;
}
if(cury == 4)
++curx, cury = 1;
}
return ret;
}
struct T {
Matrix a;
int g;
T() {}
T(Matrix _, int __) {
a = _;
g = __;
}
inline bool operator<(const T &rhs) const {
return g + h(a) > rhs.g + h(rhs.a);
}
};
map<Matrix, bool> vis;
priority_queue<T> q;
int main(void) {
ed[1][1] = 1, ed[1][2] = 2, ed[1][3] = 3;
ed[2][1] = 8, ed[2][2] = 0, ed[2][3] = 4;
ed[3][1] = 7, ed[3][2] = 6, ed[3][3] = 5;
char ch;
for(int i = 1; i <= 3; ++i)
for(int j = 1; j <= 3; ++j)
scanf(" %c", &ch), st[i][j] = ch - '0';
vis[st] = 1;
q.push(T(st, 0));
int fx, fy;
while(!q.empty()) {
T cur = q.top();
q.pop();
if(!h(cur.a)) {
printf("%d\n", cur.g);
return 0;
}
for(int i = 1; i <= 3; ++i)
for(int j = 1; j <= 3; ++j)
if(!cur.a[i][j]) {
fx = i, fy = j;
break;
}
for(int i = 0; i < 4; ++i) {
int nx = fx + dx[i], ny = fy + dy[i];
if(nx < 1 || nx > 3 || ny < 1 || ny > 3)
continue;
swap(cur.a[fx][fy], cur.a[nx][ny]);
if(vis.find(cur.a) != vis.end()) {
swap(cur.a[fx][fy], cur.a[nx][ny]);
continue;
}
vis[cur.a] = 1;
q.push(T(cur.a, cur.g + 1));
swap(cur.a[fx][fy], cur.a[nx][ny]);
}
}
return 0;
}
骑士精神
给一个初始棋盘(只含黑马 白马)问到达目标棋盘最小步数(必须在 15 步之内)。
由于有了步数限制,每次取出 \(f(x)\) 最优的同时也要考虑 \(x\to x'\ f(x')\le 15\),类似于启发式搜索可行性剪枝。因为黑马与黑马,白马与白马之间没有差别,可以定义 \(h(x)\) 为与目标状态的点的差值。
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <queue>
using namespace std;
const int dx[] = {2,-2,1,-1,2,-2,1,-1};
const int dy[] = {-1,-1,-2,-2,1,1,2,2};
struct Matrix {
int a[7][7];
inline int* operator[](const int &x) {
return a[x];
}
inline bool operator==(const Matrix &rhs) const {
for(int i = 1; i <= 5; ++i)
for(int j = 1; j <= 5; ++j)
if(a[i][j] != rhs.a[i][j]) return 0;
return 1;
}
} qwq, ed;
const int kS = 100003;
const int kM = 10007;
const int Pre = 90007;
struct HashTable {
struct HashNode {
Matrix key;
int nxt, pre;
HashNode() {}
HashNode(Matrix _, int __) {
key = _;
nxt = __;
}
} data[kS];
int head[kM], cnt;
inline int f(Matrix x) {
int ret = 0;
for(int i = 1; i <= 5; ++i)
for(int j = 1; j <= 5; ++j)
ret = (ret * 10 + x[i][j]) % kM;
return ret % kM;
}
inline int g(Matrix x) {
int ret = 0;
for(int i = 1; i <= 5; ++i)
for(int j = 1; j <= 5; ++j)
ret = (ret * 233 + x[i][j] * 11) % Pre;
return ret % Pre;
}
inline bool find(Matrix key) {
int pre = g(key);
for(int i = head[f(key)]; i; i = data[i].nxt)
if(data[i].pre == pre && data[i].key == key) return 1;
return 0;
}
inline void insert(Matrix key) {
int pos = f(key);
data[++cnt] = HashNode(key, head[pos]);
data[cnt].pre = g(key);
head[pos] = cnt;
}
inline void clear() {
cnt = 0;
memset(head, 0, sizeof head);
}
} vis;
inline int h(Matrix x) {
int ret = 0;
for(int i = 1; i <= 5; ++i)
for(int j = 1; j <= 5; ++j)
if(x[i][j] != ed[i][j]) ++ret;
return ret;
}
struct T {
Matrix a;
int g;
T() {}
T(Matrix _, int __) {
a = _;
g = __;
}
inline bool operator<(const T &rhs) const {
return g + h(a) > rhs.g + h(rhs.a);
}
};
inline void fuck(Matrix a) {
for(int i = 1; i <= 5; ++i,puts(""))
for(int j = 1; j <= 5; ++j)
printf("%d", a[i][j]);
printf("------------\n");
}
int main(void) {
int Case;
scanf("%d", &Case);
ed[1][1] = 1, ed[1][2] = 1, ed[1][3] = 1, ed[1][4] = 1, ed[1][5] = 1;
ed[2][1] = 0, ed[2][2] = 1, ed[2][3] = 1, ed[2][4] = 1, ed[2][5] = 1;
ed[3][1] = 0, ed[3][2] = 0, ed[3][3] = 2, ed[3][4] = 1, ed[3][5] = 1;
ed[4][1] = 0, ed[4][2] = 0, ed[4][3] = 0, ed[4][4] = 0, ed[4][5] = 1;
ed[5][1] = 0, ed[5][2] = 0, ed[5][3] = 0, ed[5][4] = 0, ed[5][5] = 0;
for(; Case; --Case) {
vis.clear();
int fx, fy;
char ch;
for(int i = 1; i <= 5; ++i)
for(int j = 1; j <= 5; ++j) {
scanf(" %c", &ch);
if(ch == '*') qwq[i][j] = 2;
else qwq[i][j] = ch - '0';
}
priority_queue<T> q;
q.push(T(qwq, 0));
vis.insert(qwq);
bool ok = 0;
while(!q.empty()) {
T cur = q.top();
q.pop();//printf("\t%d\n", cur.g);fuck(cur.a);
if(!h(cur.a)) {
printf("%d\n", cur.g);
ok = 1;
break;
}
bool br = 0;
for(int i = 1; i <= 5; ++i) {
for(int j = 1; j <= 5; ++j)
if(cur.a[i][j] == 2) {
fx = i, fy = j, br = 1;
break;
}
if(br) break;
}
for(int i = 0; i < 8; ++i) {
int nx = fx + dx[i];
int ny = fy + dy[i];
if(nx < 1 || nx > 5 || ny < 1 || ny > 5) continue;
swap(cur.a[nx][ny], cur.a[fx][fy]);
if(vis.find(cur.a)) {
swap(cur.a[nx][ny], cur.a[fx][fy]);
continue;
}
vis.insert(cur.a);
if(cur.g + h(cur.a) <= 15) q.push(T(cur.a, cur.g + 1));
swap(cur.a[nx][ny], cur.a[fx][fy]);
}
}
if(!ok) puts("-1");
}
return 0;
}