回溯法 (Backtracking)

本文介绍回溯法,包括递归型和非递归型,通过下面 3 个例子来解析回溯法:

  • 全排列问题
  • n 皇后问题
  • 三着色问题

回溯法

在许多递归问题当中,我们采取的方法都是穷尽所有的可能,从而找出合法的解。但是在某些情况下,当递归到某一层的时候,根据设置的判断条件,可以 judge 此解是不合法的。在这种情况下,我们就没必要再进行深层次的递归,从而可以提高算法效率。这一类算法我们称为“回溯法”,设置的判断条件称为“剪枝函数”。

回溯法的递归形式:

Input : X = {X1, X2, ..., Xn}
Output: T = (t1, t2, ..., tn)

back-track-rec(int now)
{
    for x0 in X
    {
        T[now] = x0
        if (T[0...now] is valid)  //如果有效则进行,否则尝试下一个x0
        {
            if (now == n)  //是完整解
            {
                print(T[1...now]);
                return;
            }
            else if (now < n)  //是部分解
            {
                back-track-rec(now+1);
            }
        }
    }
}

在可计算理论中,有这么一个结论:

所有的递归函数都能转换为迭代,但迭代不一定能转换为递归。

我们知道,C语言当中,函数调用是通过栈来实现的。递归实质是不断进行函数调用,直至参数达到递归的边界。所以,理论上,只要允许使用栈,那么回溯法就可以通过迭代实现。

回溯法的非递归形式:

Input : X = {X1, X2, ..., Xn}
Output: T = (t1, t2, ..., tn)

back-track-itor()
{
    int top = 0;
    while (top >= 0)
    {
        while T[top] 没有取尽 X 中的元素
        {
            T[top] = next(X)
            if (check(top) is valid)
            {
                if (top < N)    //部分解
                    print();
                else
                    top++;
            }
        }
        reset(T[top])
        top--
    }
}

使用一句话来描述回溯法的思想:对于 T[i], 尝试每一个 x0, 只要 T[i]=x0 有效,则对 T[i+1] 进行尝试,否则回退到 T[i-1] .

全排列问题

给出一个 N ,输出 N 的全排列。

首先,根据回溯法的递归形式的模板,可以写出下面的代码:

void backTrackRec2(int now)
{
    for (int i = 1; i <= N; i++)
    {
        a[now] = i;
        if (check(now))
        {
            if (now == N - 1)
            {
                print(N);
                return;
            }
            else
            {
                backTrackRec2(now + 1);
            }
        }
    }
}

而关键就是如何实现 check 函数去检查是否当前填入的 i 是否有效,全排列的 check 函数很简单:只需要 a[0...now-1] 都与 a[now] 不相等。

bool check(int now)
{
    for (int i = 0; i < now; i++)
    {
        if (a[i] == a[now])
            return false;
    }
    return true;
}

现在分析一下算法复杂度,对于每一个排列,需要对 a[0,...,(N-1)] 都执行一次 check,那么求解一个序列的复杂度为:

0 + 1 + 2 + ... + (n-1) = n(n-1)/2

现在思考如何把 check 的方法简化:开一个长度为 N+1bool 数组 table[] ,如果数字 k 已经被使用了,那么置 table[k] = true 。复杂度为 O(1)

void backTrackRec1(int a[], int N, int now)
{
    if (now == N)
    {
        print(N);
        return;
    }
    for (int x = 1; x <= N; x++)
    {
        if (table[x] == false)
        {
            a[now] = x, table[x] = true;
            backTrackRec1(a, N, now + 1);
            table[x] = false;
        }
    }
}

最后给出非递归形式的解法,a[] 相当于一个栈,k 是栈顶指针,k++ 表示进栈, k-- 表示出栈(也是回溯的过程)。

void backTrackItor()
{
    int k = 0;
    while (k >= 0)
    {
        while (a[k] < N)
        {
            a[k]++;
            if (check(k))
            {
                if (k == N - 1)
                {
                    print(N);
                    break;
                }
                else
                {
                    k++;
                }
            }
        }
        a[k] = 0;
        k--;
    }
}

n 皇后问题

使用数组 pos[N] 来表示皇后的位置,pos[i] = j 表示第 i 个皇后在位置 (i,j)

首先来看递归形式的解法:

void backTrackRec(int now)
{
    if (now == N)
    {
        print();
        return;
    }
    for (int x = 0; x < N; x++)
    {
        pos[now] = x;
        if (check(now))
        {
            backTrackRec(now + 1);
        }
    }
}

我们使用 pos 数组来记录位置,已经能保证每个皇后在不同的行上。因此,在 check 函数当中,需要检查新添的皇后是否有同列或者在对角线上(两点斜率为 1 )的情况。

bool check(int index)
{
    for (int i = 0; i < index; i++)
    {
        if (pos[i] == pos[index] || abs(i - index) == abs(pos[i] - pos[index]))
            return false;
    }
    return true;
}

再来看非递归的解法:

void backTrackItor()
{
    int top = 0;
    while (top >= 0)
    {
        while (pos[top] < N)
        {
            pos[top]++;
            if (check(top))
            {
                if (top == N-1)
                {
                    print();
                    break;
                }
                else
                {
                    top++;
                }
                
            }
        }
        pos[top--] = 0;
    }
}

本质上 n 皇后问题还是在做全排列的枚举,但是因为 check 函数的不同,实际上空间复杂度要小一些。例如当出现:「1 2」 这种情况,就会被剪枝函数 check 裁去,不再进行深一层的搜索。

三着色问题

三着色问题是指:给出一个无向图 G=(V,E), 使用三种不同的颜色为 G 中的每一个顶点着色,使得没有 2 个相邻的点具有相同的颜色。

首先,我们使用如下的数据结构:

map<int, vector<int>> graph;  //图的邻接链表表示
int v, e;  //点数,边数
int table[VMAX]; //table[i]=0/1/2, 表示点 i 涂上颜色 R/G/B

很自然的想法,我们会穷举每一个颜色序列,找出合法的解,假设有 3 个顶点,那么自然会这样尝试:

0 0 0
0 0 1
0 0 2
...

但是,这样的穷举并不是想要的结果,因为尝试的过程中没有加入 “没有 2 个相邻的点具有相同的颜色” 这样的判断。

还是直接套回溯法的模板:


void colorRec(int now)
{
    for (int i = 0; i < NCOLOR; i++)
    {
        table[now] = i;
        if (check(now))
        {
            if (now == v - 1) //完整解
            {
                print(v);
                countRec++;
                //不应有 return;
            }
            else
            {
                colorRec(now + 1);
            }
        }
    }
    table[now] = -1;
}
void colorItor()
{
    int top = 0;
    while (top >= 0)
    {
        while (table[top] < (NCOLOR - 1))
        {
            table[top]++;
            if (check(top))
            {
                if (top == v - 1)
                {
                    print(v);
                    countItor++;
                    // 不应有 break;
                }
                else
                {
                    top++;
                }
            }
        }
        table[top--] = -1;
    }
}

注意上面两处的「不应有」,这是与全排列和 n 皇后有所区别的地方。为什么呢?

假设现有 4 个顶点:

A-----B
|     
C-----D

一个合法的着色序列为:

0 1 2 0

如果对应的地方有 break 或者 return,那么上述序列就会回溯到「0 1 2」这个序列,但是实际上,在上面序列的基础上继续搜索,可以找到:

0 1 2 1

这也是一个合法的着色序列,如果加入 breakreturn ,这种情况就被忽略了。

附录

3着色代码

#include <cstring>
#include <iostream>
#include <map>
#include <vector>
#define NCOLOR 3
#define VMAX 100
#define EMAX 200
using namespace std;
map<int, vector<int>> graph;
int v, e;
int table[VMAX]; //table[i]=R/G/B, 表示点 i 涂上颜色 R/G/B
int countRec = 0, countItor = 0;
bool check(int now)
{
    for (int x : graph[now])
    {
        if (table[x] != -1 && table[x] == table[now])
            return false;
    }
    return true;
}
void print(int len)
{
    cout << "Point: ";
    for (int i = 0; i < len; i++)
    {
        cout << i << ' ';
    }
    cout << endl;
    cout << "Color: ";
    for (int i = 0; i < len; i++)
    {
        cout << table[i] << ' ';
    }
    cout << "\n"
         << endl;
}
void colorRec(int now)
{
    for (int i = 0; i < NCOLOR; i++)
    {
        table[now] = i;
        if (check(now))
        {
            if (now == v - 1) //完整解
            {
                print(v);
                countRec++;
                //不应有 return;
            }
            else
            {
                colorRec(now + 1);
            }
        }
    }
    table[now] = -1;
}
void colorItor()
{
    int top = 0;
    while (top >= 0)
    {
        while (table[top] < (NCOLOR - 1))
        {
            table[top]++;
            if (check(top))
            {
                if (top == v - 1)
                {
                    print(v);
                    countItor++;
                    // 不应有 break;
                }
                else
                {
                    top++;
                }
            }
        }
        table[top--] = -1;
    }
}
int main()
{
    memset(table, -1, sizeof(table));
    cin >> v >> e;
    int a, b;
    for (int i = 0; i < e; i++)
    {
        cin >> a >> b;
        graph[a].push_back(b);
        graph[b].push_back(a);
    }
    // colorRec(0);
    memset(table, -1, sizeof(table));
    colorItor();
    cout << countRec << " " << countItor << endl;
}

/*
Sample1:
5 7
0 1
0 2
1 3
1 4
2 3
2 4
3 4
 */

全排列代码

#include <iostream>
#include <cstring>
#define MAXN 20
using namespace std;
int a[MAXN] = {0};
bool table[MAXN] = {0};
int N = 0;
void print(int n)
{
    for (int i = 0; i < n; i++)
    {
        cout << a[i] << ' ';
    }
    cout << endl;
}
bool check(int now)
{
    for (int i = 0; i < now; i++)
    {
        if (a[i] == a[now])
            return false;
    }
    return true;
}
void backTrackRec1(int a[], int N, int now)
{
    if (now == N)
    {
        print(N);
        return;
    }
    for (int x = 1; x <= N; x++)
    {
        if (table[x] == false)
        {
            a[now] = x, table[x] = true;
            backTrackRec1(a, N, now + 1);
            table[x] = false;
        }
    }
}
void backTrackRec2(int now)
{
    for (int i = 1; i <= N; i++)
    {
        a[now] = i;
        if (check(now))
        {
            if (now == N - 1)
            {
                print(N);
                return;
            }
            else
            {
                backTrackRec2(now + 1);
            }
        }
    }
}
void backTrackItor()
{
    int k = 0;
    while (k >= 0)
    {
        while (a[k] < N)
        {
            a[k]++;
            if (check(k))
            {
                if (k == N - 1)
                {
                    print(N);
                    break;
                }
                else
                {
                    k++;
                }
            }
        }
        a[k] = 0;
        k--;
    }
}

int main()
{
    N = 3;
    for (int i = 1; i <= N; i++)
    {
        a[i - 1] = 0;
    }
    // backTrackRec1(a, N, 0);
    // backTrackRec2(0);
    backTrackItor();
}

n皇后代码

#include <iostream>
#define cout std::cout
#define endl std::endl
#define N 4
int count = 0;
int pos[N] = {0};
void print()
{
    count++;
    for (int i = 0; i < N; i++)
    {
        int r = i;
        int c = pos[i];
        for (int i = 0; i < c; i++)
            cout << "* ";
        cout << "Q ";
        for (int i = c + 1; i < N; i++)
            cout << "* ";
        cout << endl;
    }
    cout << endl;
}

bool check(int index)
{
    for (int i = 0; i < index; i++)
    {
        if (pos[i] == pos[index] || abs(i - index) == abs(pos[i] - pos[index]))
            return false;
    }
    return true;
}

void backTrackRec(int now)
{
    if (now == N)
    {
        print();
        return;
    }
    for (int x = 0; x < N; x++)
    {
        pos[now] = x;
        if (check(now))
        {
            backTrackRec(now + 1);
        }
    }
}
void backTrackItor()
{
    memset(pos, -1, sizeof(pos));
    int top = 0;
    while (top >= 0)
    {
        while (pos[top] < N)
        {
            if (++pos[top] >= N) break;
            if (check(top))
            {
                if (top == N-1)
                {
                    print();
                    break;
                }
                else
                {
                    top++;
                }
                
            }
        }
        pos[top--] = -1;
    }
}
int main()
{
    // backTrackRec(0);
    backTrackItor();
    cout << count << endl;
}

posted @ 2019-09-07 17:04  sinkinben  阅读(4880)  评论(1编辑  收藏  举报