回溯法 (Backtracking)
本文介绍回溯法,包括递归型和非递归型,通过下面 3 个例子来解析回溯法:
- 全排列问题
- n 皇后问题
- 三着色问题
回溯法
在许多递归问题当中,我们采取的方法都是穷尽所有的可能,从而找出合法的解。但是在某些情况下,当递归到某一层的时候,根据设置的判断条件,可以 judge 此解是不合法的。在这种情况下,我们就没必要再进行深层次的递归,从而可以提高算法效率。这一类算法我们称为“回溯法”,设置的判断条件称为“剪枝函数”。
回溯法的递归形式:
Input : X = {X1, X2, ..., Xn}
Output: T = (t1, t2, ..., tn)
back-track-rec(int now)
{
for x0 in X
{
T[now] = x0
if (T[0...now] is valid) //如果有效则进行,否则尝试下一个x0
{
if (now == n) //是完整解
{
print(T[1...now]);
return;
}
else if (now < n) //是部分解
{
back-track-rec(now+1);
}
}
}
}
在可计算理论中,有这么一个结论:
所有的递归函数都能转换为迭代,但迭代不一定能转换为递归。
我们知道,C语言当中,函数调用是通过栈来实现的。递归实质是不断进行函数调用,直至参数达到递归的边界。所以,理论上,只要允许使用栈,那么回溯法就可以通过迭代实现。
回溯法的非递归形式:
Input : X = {X1, X2, ..., Xn}
Output: T = (t1, t2, ..., tn)
back-track-itor()
{
int top = 0;
while (top >= 0)
{
while T[top] 没有取尽 X 中的元素
{
T[top] = next(X)
if (check(top) is valid)
{
if (top < N) //部分解
print();
else
top++;
}
}
reset(T[top])
top--
}
}
使用一句话来描述回溯法的思想:对于 T[i]
, 尝试每一个 x0
, 只要 T[i]=x0
有效,则对 T[i+1]
进行尝试,否则回退到 T[i-1]
.
全排列问题
给出一个 N ,输出 N 的全排列。
首先,根据回溯法的递归形式的模板,可以写出下面的代码:
void backTrackRec2(int now)
{
for (int i = 1; i <= N; i++)
{
a[now] = i;
if (check(now))
{
if (now == N - 1)
{
print(N);
return;
}
else
{
backTrackRec2(now + 1);
}
}
}
}
而关键就是如何实现 check
函数去检查是否当前填入的 i
是否有效,全排列的 check
函数很简单:只需要 a[0...now-1]
都与 a[now]
不相等。
bool check(int now)
{
for (int i = 0; i < now; i++)
{
if (a[i] == a[now])
return false;
}
return true;
}
现在分析一下算法复杂度,对于每一个排列,需要对 a[0,...,(N-1)]
都执行一次 check
,那么求解一个序列的复杂度为:
0 + 1 + 2 + ... + (n-1) = n(n-1)/2
现在思考如何把 check
的方法简化:开一个长度为 N+1
的 bool
数组 table[]
,如果数字 k
已经被使用了,那么置 table[k] = true
。复杂度为 O(1)
。
void backTrackRec1(int a[], int N, int now)
{
if (now == N)
{
print(N);
return;
}
for (int x = 1; x <= N; x++)
{
if (table[x] == false)
{
a[now] = x, table[x] = true;
backTrackRec1(a, N, now + 1);
table[x] = false;
}
}
}
最后给出非递归形式的解法,a[]
相当于一个栈,k
是栈顶指针,k++
表示进栈, k--
表示出栈(也是回溯的过程)。
void backTrackItor()
{
int k = 0;
while (k >= 0)
{
while (a[k] < N)
{
a[k]++;
if (check(k))
{
if (k == N - 1)
{
print(N);
break;
}
else
{
k++;
}
}
}
a[k] = 0;
k--;
}
}
n 皇后问题
使用数组 pos[N]
来表示皇后的位置,pos[i] = j
表示第 i
个皇后在位置 (i,j)
。
首先来看递归形式的解法:
void backTrackRec(int now)
{
if (now == N)
{
print();
return;
}
for (int x = 0; x < N; x++)
{
pos[now] = x;
if (check(now))
{
backTrackRec(now + 1);
}
}
}
我们使用 pos
数组来记录位置,已经能保证每个皇后在不同的行上。因此,在 check
函数当中,需要检查新添的皇后是否有同列或者在对角线上(两点斜率为 1
)的情况。
bool check(int index)
{
for (int i = 0; i < index; i++)
{
if (pos[i] == pos[index] || abs(i - index) == abs(pos[i] - pos[index]))
return false;
}
return true;
}
再来看非递归的解法:
void backTrackItor()
{
int top = 0;
while (top >= 0)
{
while (pos[top] < N)
{
pos[top]++;
if (check(top))
{
if (top == N-1)
{
print();
break;
}
else
{
top++;
}
}
}
pos[top--] = 0;
}
}
本质上 n 皇后问题还是在做全排列的枚举,但是因为 check
函数的不同,实际上空间复杂度要小一些。例如当出现:「1 2
」 这种情况,就会被剪枝函数 check
裁去,不再进行深一层的搜索。
三着色问题
三着色问题是指:给出一个无向图 G=(V,E), 使用三种不同的颜色为 G 中的每一个顶点着色,使得没有 2 个相邻的点具有相同的颜色。
首先,我们使用如下的数据结构:
map<int, vector<int>> graph; //图的邻接链表表示
int v, e; //点数,边数
int table[VMAX]; //table[i]=0/1/2, 表示点 i 涂上颜色 R/G/B
很自然的想法,我们会穷举每一个颜色序列,找出合法的解,假设有 3 个顶点,那么自然会这样尝试:
0 0 0
0 0 1
0 0 2
...
但是,这样的穷举并不是想要的结果,因为尝试的过程中没有加入 “没有 2 个相邻的点具有相同的颜色” 这样的判断。
还是直接套回溯法的模板:
void colorRec(int now)
{
for (int i = 0; i < NCOLOR; i++)
{
table[now] = i;
if (check(now))
{
if (now == v - 1) //完整解
{
print(v);
countRec++;
//不应有 return;
}
else
{
colorRec(now + 1);
}
}
}
table[now] = -1;
}
void colorItor()
{
int top = 0;
while (top >= 0)
{
while (table[top] < (NCOLOR - 1))
{
table[top]++;
if (check(top))
{
if (top == v - 1)
{
print(v);
countItor++;
// 不应有 break;
}
else
{
top++;
}
}
}
table[top--] = -1;
}
}
注意上面两处的「不应有」,这是与全排列和 n 皇后有所区别的地方。为什么呢?
假设现有 4 个顶点:
A-----B
|
C-----D
一个合法的着色序列为:
0 1 2 0
如果对应的地方有 break
或者 return
,那么上述序列就会回溯到「0 1 2」
这个序列,但是实际上,在上面序列的基础上继续搜索,可以找到:
0 1 2 1
这也是一个合法的着色序列,如果加入 break
或 return
,这种情况就被忽略了。
附录
3着色代码
#include <cstring>
#include <iostream>
#include <map>
#include <vector>
#define NCOLOR 3
#define VMAX 100
#define EMAX 200
using namespace std;
map<int, vector<int>> graph;
int v, e;
int table[VMAX]; //table[i]=R/G/B, 表示点 i 涂上颜色 R/G/B
int countRec = 0, countItor = 0;
bool check(int now)
{
for (int x : graph[now])
{
if (table[x] != -1 && table[x] == table[now])
return false;
}
return true;
}
void print(int len)
{
cout << "Point: ";
for (int i = 0; i < len; i++)
{
cout << i << ' ';
}
cout << endl;
cout << "Color: ";
for (int i = 0; i < len; i++)
{
cout << table[i] << ' ';
}
cout << "\n"
<< endl;
}
void colorRec(int now)
{
for (int i = 0; i < NCOLOR; i++)
{
table[now] = i;
if (check(now))
{
if (now == v - 1) //完整解
{
print(v);
countRec++;
//不应有 return;
}
else
{
colorRec(now + 1);
}
}
}
table[now] = -1;
}
void colorItor()
{
int top = 0;
while (top >= 0)
{
while (table[top] < (NCOLOR - 1))
{
table[top]++;
if (check(top))
{
if (top == v - 1)
{
print(v);
countItor++;
// 不应有 break;
}
else
{
top++;
}
}
}
table[top--] = -1;
}
}
int main()
{
memset(table, -1, sizeof(table));
cin >> v >> e;
int a, b;
for (int i = 0; i < e; i++)
{
cin >> a >> b;
graph[a].push_back(b);
graph[b].push_back(a);
}
// colorRec(0);
memset(table, -1, sizeof(table));
colorItor();
cout << countRec << " " << countItor << endl;
}
/*
Sample1:
5 7
0 1
0 2
1 3
1 4
2 3
2 4
3 4
*/
全排列代码
#include <iostream>
#include <cstring>
#define MAXN 20
using namespace std;
int a[MAXN] = {0};
bool table[MAXN] = {0};
int N = 0;
void print(int n)
{
for (int i = 0; i < n; i++)
{
cout << a[i] << ' ';
}
cout << endl;
}
bool check(int now)
{
for (int i = 0; i < now; i++)
{
if (a[i] == a[now])
return false;
}
return true;
}
void backTrackRec1(int a[], int N, int now)
{
if (now == N)
{
print(N);
return;
}
for (int x = 1; x <= N; x++)
{
if (table[x] == false)
{
a[now] = x, table[x] = true;
backTrackRec1(a, N, now + 1);
table[x] = false;
}
}
}
void backTrackRec2(int now)
{
for (int i = 1; i <= N; i++)
{
a[now] = i;
if (check(now))
{
if (now == N - 1)
{
print(N);
return;
}
else
{
backTrackRec2(now + 1);
}
}
}
}
void backTrackItor()
{
int k = 0;
while (k >= 0)
{
while (a[k] < N)
{
a[k]++;
if (check(k))
{
if (k == N - 1)
{
print(N);
break;
}
else
{
k++;
}
}
}
a[k] = 0;
k--;
}
}
int main()
{
N = 3;
for (int i = 1; i <= N; i++)
{
a[i - 1] = 0;
}
// backTrackRec1(a, N, 0);
// backTrackRec2(0);
backTrackItor();
}
n皇后代码
#include <iostream>
#define cout std::cout
#define endl std::endl
#define N 4
int count = 0;
int pos[N] = {0};
void print()
{
count++;
for (int i = 0; i < N; i++)
{
int r = i;
int c = pos[i];
for (int i = 0; i < c; i++)
cout << "* ";
cout << "Q ";
for (int i = c + 1; i < N; i++)
cout << "* ";
cout << endl;
}
cout << endl;
}
bool check(int index)
{
for (int i = 0; i < index; i++)
{
if (pos[i] == pos[index] || abs(i - index) == abs(pos[i] - pos[index]))
return false;
}
return true;
}
void backTrackRec(int now)
{
if (now == N)
{
print();
return;
}
for (int x = 0; x < N; x++)
{
pos[now] = x;
if (check(now))
{
backTrackRec(now + 1);
}
}
}
void backTrackItor()
{
memset(pos, -1, sizeof(pos));
int top = 0;
while (top >= 0)
{
while (pos[top] < N)
{
if (++pos[top] >= N) break;
if (check(top))
{
if (top == N-1)
{
print();
break;
}
else
{
top++;
}
}
}
pos[top--] = -1;
}
}
int main()
{
// backTrackRec(0);
backTrackItor();
cout << count << endl;
}