一个不剪枝的程序引发的血案 - 回溯了就请尽量剪枝吧!
回溯和穷举没啥两样,区别只在于回溯通过栈,以空间换取时间,节省的是前面重复的计算。因此在用回溯算法解决问题时,能剪枝就剪枝,一定要跳过不必要的计算过程。
问题:求10次射击共中80环的次数(概率)。即,给定正整数N=10(射击十次),M=10(共十环),求满足x_0+x_1+...+X_{N-1}=80的代数式的个数。
最简单的莫过于递归:
double fun(int n, int sum) { if(n<ceil(sum/10.)) return 0; if(n==1) return 1; double s = 0; for(int i=0; i<=10; i++) s+=fun(n-1, sum-i); return s; }
但是递归往往难以剪枝,而且递归代码的简短带来的是函数调用的代价和栈溢出的风险,有时效率大概比穷举还慢。
第二种方法是回溯,也很简单啦,设置一个大小为N的数组(或者栈),初始时为 0 0 0 0 0 ... 0,然后模拟M+1进制的加法:
double fun_stk(int n, int target) { int c = 0; double ret = 0.; int const M=10; double sum = target; deque<int> s; while(s.size()<n) s.push_back(10); while(1) { if(sum >= target) { ret++; } //加一 s.back()++; if(s.back() > M) { //进位 while(s.size() && s.back()>=M) { sum -= M; s.pop_back(); } if(s.empty()) break; s.back()++; while(s.size()<n) s.push_back(0); } sum++;
} return ret;
}
对于回溯算法,应该尽可能地剪枝。然而有些算法是没办法剪枝的,比如八皇后问题(eight-queens),比如数独(sudoku)。如果剪枝无力,就该考虑其他方法了,比如数独就存在DFS和BFS的解法。
于是进一步考虑怎么剪枝。举几个例子说:设N=4,target=25,即射击4次得25环。
1 我们不需要从0 0 0 0开始,而应该从第一个sum=25的序列开始:0 5 10 10。求法是从后往前依次放入小于等于M的数,直到sum=25为止。
2 在加法一步一步执行的过程中,一旦sum>=target了,那么继续加一就没有意义了,因为加一势必导致sum>target,于是强制使其在下一次+1时进位,即将最后一位设为M。
3 依据2得到的序列中,势必会有sum<target,既然如此我们也不需要按部就班地加一,因为可以跳跃式地直接加上target-sum,同时把序列补完。
这样的到的代码,和每次中规中矩地加一比起来,通过跳跃式地加法,跳过了中间大部分不需要的步骤:
double fun_stk(int n, int target) { double ret = 0.; int const M=10; double sum = target; deque<int> s; int tmp = ceil(sum/M); while(s.size()<n-tmp) s.push_back(0); s.push_back(sum-(tmp-1)*M); while(s.size()<n) s.push_back(10); while(1) { if(sum >= target) { if(sum==target) ret++; sum+=M-s.back(); s.back() = M; } //加一 s.back()++; if(s.back() > M) { //进位 while(s.size() && s.back()>=M) { sum -= M; s.pop_back(); } if(s.empty()) break; s.back()++; while(s.size()<n) s.push_back(0); } sum++; //剪枝. 直接将s内的数加到下一个sum=target的状态 tmp = target - sum; if(tmp>0) { //将不足的数补到后面 size_t last = n; while(tmp>0 && (--last)>=0) { int rem = M - s[last]; if(rem>tmp) { s[last]+=tmp; sum+=tmp; break; } else { s[last] += rem; sum+=rem; tmp -= rem; } } } } return ret; }
然而这样还没结束。从0 0 0 0到0 5 10 10,实际上也是3可以完成的工作,为了代码的简短,把while前面不必要的运算去掉:
double fun_stk(int n, int target) { double ret = 0.; int const M=10; double sum = 0.; deque<int> s(n, 0); int tmp; while(1) { if(sum >= target) { if(sum==target) ret++; sum+=M-s.back(); s.back() = M; } //加一 s.back()++; if(s.back() > M) { //进位 while(s.size() && s.back()>=M) { sum -= M; s.pop_back(); } if(s.empty()) break; s.back()++; while(s.size()<n) s.push_back(0); } sum++; //剪枝 tmp = target - sum; if(tmp>0) { //将不足的数补到后面 size_t last = n; while(tmp>0 && (--last)>=0) { int rem = M - s[last]; if(rem>tmp) { s[last]+=tmp; sum+=tmp; break; } else { s[last] += rem; sum+=rem; tmp -= rem; } } } } return ret; }
(: 由于deque是一个较为复杂的数据结构,通过RandomIterator去访问它并不是一个好的选择。最初写成deque是为了使算法更像是回溯算法这么回事。此后需将其改成固定长度为N的数组,deque的出入栈操作可以用数组代替完成,提高效率。
第十八式,打完收工~