搜索剪枝练习
搜索剪枝练习
搜索是一类很暴力的做法,往往时间复杂度都是指数级别的,大部分时候都无法作为正解使用。不过可以通过一些剪枝技巧,减小搜索规模,加快程序运行速度。
P1025 [NOIP2001 提高组] 数的划分
题目描述
将整数 \(n\) 分成 \(k\) 份,且每份不能为空,任意两个方案不相同(不考虑顺序)。
例如:\(n=7\),\(k=3\),下面三种分法被认为是相同的。
\(1,1,5\);
\(1,5,1\);
\(5,1,1\).
问有多少种不同的分法。
输入格式
\(n,k\) (\(6<n \le 200\),\(2 \le k \le 6\))
输出格式
\(1\) 个整数,即不同的分法。
样例 #1
样例输入 #1
7 3
样例输出 #1
4
提示
四种分法为:
\(1,1,5\);
\(1,2,4\);
\(1,3,3\);
\(2,2,3\).
【题目来源】
NOIP 2001 提高组第二题
Solution
看到数据范围,既然分组只有 \(k\le 6\) 组,那么直接暴搜就行了。
不过暴搜可能会超时,所以需要加一些剪枝优化。
上下界优化:假设目前的分组方案是 \(a\),当前数是 \(a_i\),那么下界保证 \(a_i\ge a_{i-1}\),上界根据这个条件设置为 \(\displaystyle\frac{n-sum}{k-i+1}\) 即可。
这样其实就可以 AC 了,不过还可以再加一个小的剪枝。如果搜到了第 \(k\) 组那么直接统计入答案就可以了,因为最后一组一定就是分剩下的数,而不需要再去枚举了。
#include<bits/stdc++.h>
using namespace std;
int n, k, ans = 0;
void dfs(int x, int sum, int last) {
if (x == k) return ans++, void();
for (int i = last, maxn = (n - sum) / (k - x + 1); i <= maxn; i++) dfs(x + 1, sum + i, i);
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n >> k;
dfs(1, 0, 1);
cout << ans << '\n';
return 0;
}
P1120 小木棍
题目描述
乔治有一些同样长的小木棍,他把这些木棍随意砍成几段,直到每段的长都不超过 \(50\)。
现在,他想把小木棍拼接成原来的样子,但是却忘记了自己开始时有多少根木棍和它们的长度。
给出每段小木棍的长度,编程帮他找出原始木棍的最小可能长度。
输入格式
第一行是一个整数 \(n\),表示小木棍的个数。
第二行有 \(n\) 个整数,表示各个木棍的长度 \(a_i\)。
输出格式
输出一行一个整数表示答案。
样例 #1
样例输入 #1
9
5 2 1 5 2 1 5 2 1
样例输出 #1
6
提示
对于全部测试点,\(1 \leq n \leq 65\),\(1 \leq a_i \leq 50\)。
Solution
这道题的剪枝可谓是丧心病狂。
暴搜的思路很好想到,但是剪枝优化的思路确实很难想。
-
将木棍从大到小排序,从最大的开始选择(优化搜索顺序)
这一点其实很好理解,因为大的木棍肯定没有小的木棍灵活,也就是说小的木棍可以很多根一起组合拼到一个地方,也可以拆散单个去补大的木棍的空。而大的木棍可以放的地方肯定没有小得多,所以将大的木棍放在最开始可以减少很多无用状态的搜索。
-
保证每次搜索的长度是递减的(排除等效冗余)
对于两个木棍 \(x,y\),先拼 \(x\) 后拼 \(y\) 和先拼 \(y\) 后拼 \(x\) 显然是相同的,所以不如人为规定 \(l_x > l_y\)。
-
去除相同长度的木棍(排除等效冗余)
如果一根长度为 \(l\) 的木棍拼进去拼不出来的话,那么其他相同长度的木棍拼进去也一定拼不出来。
-
拼入第一根木棍就失败或最后一根刚好时直接回溯(排除等效冗余)
如果第一根木棍就失败,那么这根木棍放在任何位置都是不可行的,所以直接回溯。此外如果当前木棍总长为 \(x\),答案为 \(len\),如果当前木棍长度 \(a_i=len-x\) 并且失败了,那么也直接回溯,原因同上。
有了这些剪枝就可以通过这道题了,跑的还算比较快。
#include<bits/stdc++.h>
using namespace std;
int a[105], v[105], n, len, cnt, nxt[105];
bool dfs(int stick, int cab, int last) {
if (stick > cnt) return true;
if (cab == len) return dfs(stick + 1, 0, 1);
for (int i = last; i <= n; i++) {
if (!v[i] && cab + a[i] <= len) {
v[i] = 1;
if (dfs(stick, cab + a[i], i + 1)) return true;
v[i] = 0;
if (cab == 0) return false;
if (cab + a[i] == len) return false;
i = nxt[i];
}
}
return false;
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n;
int sum = 0;
for (int i = 1; i <= n; i++) {
cin >> a[i];
sum += a[i];
}
sort(a + 1, a + n + 1);
reverse(a + 1, a + n + 1);
nxt[n] = n;
for (int i = n - 1; i; i--) {
if (a[i] == a[i + 1]) nxt[i] = nxt[i + 1];
else nxt[i] = i;
}
for (len = a[1]; len <= sum / 2; len++) {
if (sum % len != 0) continue;
cnt = sum / len;
if (dfs(1, 0, 1)) break;
}
cout << (len <= sum / 2 ? len : sum) << '\n';
return 0;
}