递归与分治
递归
递归:直接或间接(A:这件事找 B;B:这件事找 A)地用到自己。
如何定义正整数?正整数是 \(1,2,3,\cdots\) 这些数。这个定义不是那么的“严密”,换一种方式:
- \(1\) 是正整数
- 如果 \(n\) 是正整数,\(n+1\) 也是正整数
这种定义就是递归式的:在“正整数”还没有定义时,就用到了“正整数”的定义。递归式定义能让定义简洁而严密。
例题:P5739 [深基7.例7] 计算阶乘
数学函数也可以递归定义,阶乘函数 \(f(n)=n!\) 可以定义为:
\( \ f(n) = \begin{cases} 1 & \quad n=0 \\ f(n-1) \times n, &\quad n \ge 1\\ \end{cases} \ \)
#include <cstdio>
int f(int n) {
return n == 0 ? 1 : f(n - 1) * n;
}
int main()
{
int n; scanf("%d", &n);
printf("%d\n", f(n));
return 0;
}
C/C++ 语言支持递归,即函数可以直接或间接地调用自己,但要注意为递归函数编写终止条件,否则将会产生无限递归。
习题:B2142 求 1+2+3+...+N 的值
解题思路
与阶乘类似,\(1+2+ \cdots + n = (1 + 2 + \cdots + (n-1)) + n\),因而若 \(f(n)\) 表示累加函数,则对于 \(n>0\) 时有 \(f(n)=f(n-1)+n\)。
#include <cstdio>
int f(int n) {
return n == 0 ? 0 : f(n - 1) + n;
}
int main()
{
int n; scanf("%d", &n);
printf("%d\n", f(n));
return 0;
}
例题:P1464 Function
分析:直接按照题意实现这个函数并不难,但是提交上去无法通过本题,程序的运行效率很低导致超时。这是因为计算过程中可能会做很多的无用功,比如 \(w(2,3,4)\) 的计算和 \(w(2,3,3), w(2,2,3)\) 有关,而 \(w(2,3,3)\) 和 \(w(2,2,3)\) 的计算又都和 \(w(1,2,3)\) 有关,也就是说,对于 \(w(1,2,3)\) 这样计算结果固定的情况,可能会被调用多次,当 \(a,b,c\) 更大的时候这样的情况还会更严重,而这样的重复计算其实造成了时间上的浪费。
本题真正需要计算的是当 \(a,b,c\) 在 \([1,20]\) 的时候,实际上最多只有 \(20 \times 20 \times 20\) 种情况需要展开递归计算。可以定义一个数组 vis
,其中每一项 vis[i][j][k]
表示 \(w(i,j,k)\) 是否曾经计算过,初始值为 false
,表示未被计算过。当需要展开递归计算时,在最终计算完成的时候将结果存入一个答案数组并将对应的 vis
值更新为 true
。这样下一次如果遇到同样的计算问题时,因为 vis
中对应的值为 true
,就可以直接从答案数组中调出之前计算过的结果而无需重新展开计算。
每种情况最多只计算一次,一旦计算完成就会被存下来,便于日后使用。这样的思想称为“记忆化”。
#include <cstdio>
using ll = long long;
const int N = 25;
bool vis[N][N][N]; // 记录某状态是否被计算过
ll ans[N][N][N]; // 记录某状态下的计算结果
ll w(ll a, ll b, ll c) {
if (a <= 0 || b <= 0 || c <= 0) return 1;
if (a > 20 || b > 20 || c > 20) return w(20, 20, 20);
if (vis[a][b][c]) return ans[a][b][c]; // 如果曾经计算过可以直接返回存下来的结果
// 需要展开计算则在计算完成后更新相应状态及计算结果
if (a < b && b < c) {
ans[a][b][c] = w(a, b, c-1) + w(a, b-1, c-1) - w(a, b-1, c);
vis[a][b][c] = true;
return ans[a][b][c];
}
ans[a][b][c] = w(a-1, b, c) + w(a-1, b-1, c) + w(a-1, b, c-1) - w(a-1, b-1, c-1);
vis[a][b][c] = true;
return ans[a][b][c];
}
int main()
{
while (true) {
ll a, b, c; scanf("%lld%lld%lld", &a, &b, &c);
if (a == -1 && b == -1 && c == -1) break;
printf("w(%lld, %lld, %lld) = %lld\n", a, b, c, w(a, b, c));
}
return 0;
}
例题:P1928 外星密码
分析:如果只有一层方括号,那么只需要找到方括号,就可以提取出重复次数,然后将重复部分按次数复制若干份拼接起来即可。如果方括号的“重复部分”里还有方括号呢?用同样的方式展开即可。可以发现,这个机制和递归非常吻合,因此本题适合用递归的方式来实现。
#include <iostream>
#include <string>
using std::string;
using std::cin;
using std::cout;
string s;
int len, idx; // idx为全局变量,用来控制整个字符串的处理进度
string solve() {
string ret, tmp;
int rep = 0;
// 遇到']'说明本层压缩串需要重复若干份作为解压缩结果返回给上一层
while (idx < len && s[idx] != ']') {
if (s[idx] == '[') { // 遇到'['说明需要往里展开一层压缩
idx++;
tmp += solve();
} else if (s[idx] >= '0' && s[idx] <= '9') { // 更新重复次数
rep = rep * 10 + s[idx] - '0';
idx++;
} else {
tmp += s[idx];
idx++;
}
}
idx++;
ret += tmp;
// 如果rep不等于0,说明这一层属于需要重复拼接展开的
for (int i = 1; i <= rep - 1; i++) ret += tmp;
return ret;
}
int main()
{
cin >> s;
len = s.size();
cout << solve() << "\n";
return 0;
}
如果能将一个大的任务分解成若干规模较小的任务,而且这些任务的形式与结构和原问题一致,就可以考虑使用递归。当问题规模足够小或者达到了边界条件就要停止递归。分解完问题后还要将这些规模小的任务的处理结果合并,最后逐级上报,解决最大规模的问题。
分治
如果想知道我国的人口数量,就需要进行人口普查。让每一个省份都去统计本省有多少人,然后将各省人口累加起来,就可以获得全国的人口数量。而要想知道某一个省的人口数量,可以让省里的每一个城市统计本市有多少人,然后将各市人口累加起来,就可以获得这个省的人口数量……以此类推,层层细分,最后统计一个村子或者一个小区有多少人,这个任务就足够简单了。把一个复杂的问题细分成若干结构相同但规模更小的子问题,然后将每个子问题的解合并起来,就得到了复杂问题的解,这就是分治策略。
P5461 赦免战俘
#include <cstdio>
const int N = 1050;
int a[N][N];
// 左上角坐标(x,y),边长为len的正方形
void solve(int x, int y, int len) {
// 先考虑边界条件
if (len==1) {
a[x][y]=1;
return;
}
// 拆分问题
// 左上角全为0(相当于不用处理)
// 继续用同样的方式处理右上,左下,右下
solve(x,y+len/2,len/2);
solve(x+len/2,y,len/2);
solve(x+len/2,y+len/2,len/2);
}
int main()
{
int n; scanf("%d",&n);
// n = (1<<n);
int len=1;
for (int i=1;i<=n;i++) len*=2;
solve(1,1,len);
for (int i=1;i<=len;i++) {
for (int j=1;j<=len;j++) {
printf("%d ",a[i][j]);
}
printf("\n");
}
return 0;
}
P1228 地毯填补问题
#include <cstdio>
int px, py;
int judge(int xx, int yy, int x, int y, int n) { // 判断残缺的块在哪个分区
if (xx < x + n / 2) return yy < y + n / 2 ? 1 : 2; // 左上/右上
return yy < y + n / 2 ? 3 : 4; // 左下/右下
}
void solve(int n, int x, int y, int miss, int xx, int yy) {
if (n == 1) return;
n = n / 2;
if (miss == 1) {
printf("%d %d %d\n", x + n, y + n, 1);
solve(n, x, y, judge(xx, yy, x, y, n), xx, yy);
solve(n, x, y + n, 3, x + n - 1, y + n);
solve(n, x + n, y, 2, x + n, y + n - 1);
solve(n, x + n, y + n, 1, x + n, y + n);
} else if (miss == 2) {
printf("%d %d %d\n", x + n, y + n - 1, 2);
solve(n, x, y, 4, x + n - 1, y + n - 1);
solve(n, x, y + n, judge(xx, yy, x, y + n, n), xx, yy);
solve(n, x + n, y, 2, x + n, y + n - 1);
solve(n, x + n, y + n, 1, x + n, y + n);
} else if (miss == 3) {
printf("%d %d %d\n", x + n - 1, y + n, 3);
solve(n, x, y, 4, x + n - 1, y + n - 1);
solve(n, x, y + n, 3, x + n - 1, y + n);
solve(n, x + n, y, judge(xx, yy, x + n, y, n), xx, yy);
solve(n, x + n, y + n, 1, x + n, y + n);
} else {
printf("%d %d %d\n", x + n - 1, y + n - 1, 4);
solve(n, x, y, 4, x + n - 1, y + n - 1);
solve(n, x, y + n, 3, x + n - 1, y + n);
solve(n, x + n, y, 2, x + n, y + n - 1);
solve(n, x + n, y + n, judge(xx, yy, x + n, y + n, n), xx, yy);
}
}
int main()
{
int k;
scanf("%d%d%d", &k, &px, &py);
int len = 1;
for (int i = 1; i <= k; i++) len *= 2;
solve(len, 1, 1, judge(px, py, 1, 1, len), px, py);
return 0;
}
归并排序
例题:P1177 【模板】排序
介绍一种新的排序算法——归并排序。要理解归并排序,首先要理解归并。考虑这样一个问题:给定两个有序的序列 \(a,b\),把两个序列合并成一个序列,使得合并出的这个序列是有序的。
算法的过程很简单,维护两个位置 \(i\) 和 \(j\),代表当前考虑 \(a\) 数组的第 \(i\) 个元素与 \(b\) 数组的第 \(j\) 个元素。如果 \(a_i \le b_j\),则在答案数组添加一个 \(a_i\),同时 \(i\) 向后移动。如果 \(a_i > b_j\),则在答案数组添加一个 \(b_j\),同时 \(j\) 向后移动。注意到,如果 \(a,b\) 两个数组中有一个被合并完了,可以直接把另一个数组剩下的部分接到答案数组最后面。
比如有两个有序数组 \(a=[1,3,7,8], b=[2,4,6,9]\),对这两个数组进行归并:
有了归并算法之后,要对一个长度为 \(n\) 的序列进行排序,可以考虑采用分治的思想来解决:如果 \(n=1\),这个序列自然是有序的,所以不用进行排序——这就是可以直接解决的子问题。否则,将序列分为两个长 \(\frac{n}{2}\) 的子序列,对这两个子序列分别递归地进行排序——这是把一个复杂的问题转换为若干个简单一些的问题,然后递归下去解决这些更简答的问题。
当两个子序列有序后,对这两个子序列进行归并,使当前这个长度为 \(n\) 的序列有序——这就是当每个子问题都处理完之后,合并子问题的答案得到原问题的答案。
归并排序的时间复杂度为 \(T(n)=2T(\frac{n}{2})+O(n)=O(n \log n)\)。
推导
#include <cstdio>
const int N = 1e5 + 5;
int a[N], tmp[N]; // tmp是合并时用的临时数组
void mergesort(int l, int r) { // 实现对a[l]~a[r]完成排序
if (l==r) { // 只剩一个数,无需排序,直接返回
return;
}
int mid=(l+r)/2; // a[l]~a[mid] a[mid+1]~a[r]
mergesort(l,mid); mergesort(mid+1,r); // 递归到更小的子问题
// 上面这两个递归调用返回之后意味着左半边和右半边内部已经有序
// 接下来要解决合并的问题
// a[l]~a[mid] a[mid+1]~a[r]
// 先合并到 tmp[l]~tmp[r]
// 最后再搬回 a
int i=l, j=mid+1; // 两部分的合并进度
int k=l; // 下一个数据合并到tmp的什么位置
while (i<=mid && j<=r) {
if (a[i] <= a[j]) {
tmp[k]=a[i]; i++;
} else {
tmp[k]=a[j]; j++;
}
k++;
}
// 上面循环结束时必然是左右半区的其中一个已经合并完成
// 另一个必然还剩下最后一段没有合并进去
while (i<=mid) {
tmp[k]=a[i]; i++; k++;
}
while (j<=r) {
tmp[k]=a[j]; j++; k++;
}
// 此时tmp[l]~tmp[r]已经合并完成,搬回原数组a
for (int i=l;i<=r;i++) a[i]=tmp[i];
}
int main()
{
int n; scanf("%d",&n);
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
mergesort(1,n);
for (int i=1;i<=n;i++) printf("%d ",a[i]);
return 0;
}
例题:P1908 逆序对
对于给定的一段正整数序列 \(a\),逆序对是序列中 \(a_i>a_j\) 且 \(i<j\) 的有序对。求一个长度为 \(n\) 的序列的逆序对个数,其中 \(1 \le n \le 5 \times 10^5\)。
分析:对于这个问题,可以在归并排序的过程中同时求出序列的逆序对数。
如果 \(n=1\),这个序列的逆序对个数自然是 \(0\)——这就是可以直接解决的子问题。否则,将序列分为两个长度为 \(\frac{n}{2}\) 的子序列,对这两个子序列分别递归地求出其内部的逆序对——这是把一个复杂的问题转换为若干个简单一些的问题,然后递归下去解决简单一些的问题。
当递归计算了两个子序列内部的逆序对数后,考虑怎么合并这两个子序列。可以发现逆序对还有一种来源,前一个序列中某个元素和后一个序列中某个元素所构成的逆序对,因此还要计算这部分的个数——这就是当每个子问题都处理完之后,合并子问题的答案得到原问题的答案。
那么如何算这种一前一后的情况呢?由于在序列位置中,前一个子序列中的元素一定在后一个子序列中的元素的前面,所以逆序对的 \(i<j\) 已经自然满足了,只需要再考虑 \(a_i>a_j\)。
回顾归并排序的归并过程。对两个有序数组 \(a\) 和 \(b\) 归并的时候,如果某次比较之后在答案数组中放入的元素是 \(b_j\),而和 \(b_j\) 做比较的元素是 \(a_i\),那么一定有 \(a_i, a_{i+1}, \dots\) 均比 \(b_j\) 大,所以在归并排序的过程中,每当在答案数组中放入 \(b_j\) 时,会产生一批逆序对,这样就可以边归并排序边求出整个序列的逆序对数了。
时间复杂度和归并排序一样,为 \(O(n \log n)\)。
#include <cstdio>
using ll = long long;
const int N = 5e5 + 5;
int a[N], tmp[N]; // tmp是合并时用的临时数组
ll mergesort(int l, int r) { // 实现对a[l]~a[r]完成排序
if (l==r) { // 只剩一个数,无需排序
return 0;
}
int mid=(l+r)/2; // a[l]~a[mid] a[mid+1]~a[r]
ll sum=0;
sum += mergesort(l,mid);
sum += mergesort(mid+1,r);
// 上面这两个递归调用返回之后意味着左半边和右半边内部已经有序
// 接下来要解决合并的问题
// a[l]~a[mid] a[mid+1]~a[r]
// 先合并到 tmp[l]~tmp[r]
// 最后再搬回 a
int i=l, j=mid+1; // 两部分的合并进度
int k=l; // 下一个数据合并到tmp的什么位置
while (i<=mid && j<=r) {
if (a[i]<=a[j]) { // a[i]<=a[j]说明这次合并取左边的数
tmp[k]=a[i]; i++;
} else { // a[i]>a[j] 取右边的数
// (a[i],a[j]) 构成了逆序对
// (a[i+1,...mid],a[j]) 都构成了逆序对
sum+=(mid-i+1);
tmp[k]=a[j]; j++;
}
k++;
}
// 上面循环结束时必然是左右半区的其中一个已经合并完成
// 另一个必然还剩下最后一段没有合并进去
while (i<=mid) {
tmp[k]=a[i]; i++; k++;
}
while (j<=r) {
tmp[k]=a[j]; j++; k++;
}
// 此时tmp[l]~tmp[r]已经合并完成,搬回原数组a
for (int i=l;i<=r;i++) a[i]=tmp[i];
return sum;
}
int main()
{
int n; scanf("%d",&n);
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
ll ans=mergesort(1,n);
printf("%lld\n",ans);
return 0;
}
快速排序
快速排序的思想是,找一个基准元素,想办法把数组进行划分,比基准元素小的元素放在它的左边,比基准元素大的元素放在它的右边。左右两边各自排序,排完以后整个数组就是有序的了,不需要再合并。显然快速排序也是基于分治思想解决问题的算法。
具体来说,假设数需要排序的数组为 \(a\)(为了方便叙述,先假设没有重复数字),选中 \(a_1\) 作为基准元素,其后的每一个元素与之比大小。尽可能将小于 \(a_1\) 的元素放在前面,大于 \(a_1\) 的数放在后面,\(a_1\) 放在中间。实现的方法是,设置两个变量 i
和 j
,i
指向左边元素后面的第一个元素,j
指向最后一个元素。如果 a[i]<a[1]
,i
就指向下一个元素,即 i++
;若 a[i]>a[1]
,i
停在当前位置。另一侧,如果 a[j]>a[1]
,j
指向“下一个”元素,即 j--
;若 a[j]<a[1]
,则 j
停在当前位置。此时,a[i]
比较大,应该向右换,而 a[j]
比较小,应该向左换,正好可以把 i
和 j
位置上的元素交换。交换后,便可以得到 a[i]<a[1]
并且 a[j]>a[1]
。之后 i
加一,j
减一,继续以上操作,直至 i,j
相遇并错位。最后一步,将 a[j]
与 a[1]
进行交换,a[j]
成为整个数组的基准元素。整个数组被分成两个部分,左半部分 a[1]~a[j-1]
都小于 a[j]
,右半部分 a[j+1]~a[n]
都大于 a[j]
。这个划分过程的时间复杂度是 \(O(n)\)。
例如,原始数组为 \([4,2,1,5,7,8,0,9,3,6]\),一开始选择数组第一个元素 \(4\) 作为基准元素,\(i\) 指向基准元素后面的 \(2\),\(j\) 指向最后一个元素 \(6\)。
首先是将 \(i\) 向后移动,越过比 \(a_1\) 小的 \(2\) 和 \(1\),最终 \(i\) 停在 \(5\) 上。
接下来将 \(j\) 向左移动,\(j\) 跳过比 \(4\) 大的 \(6\),停留在比 \(4\) 小的 \(3\) 上。此时 \(i\) 指向的数比 \(4\) 大,想换到右边去,\(j\) 指向的数比 \(4\) 小,想换到左边去,正好可以交换位置。交换完以后,\(i\) 再右移一步停在 \(7\) 上,\(j\) 向左移一步停在 \(9\) 上。
继续进行上述操作,\(i\) 停在 \(7\) 上,\(j\) 停在 \(0\) 上,\(a_i\) 和 \(a_j\) 互换位置,\(i\) 向右,\(j\) 向左。接下来,\(i\) 停在 \(8\) 上,\(j\) 向左走到 \(0\) 上。此时 \(i\) 和 \(j\) 已经错位了,将基准元素 \(a_1\) 与 \(a_j\) 交换,交换后基准元素到了 \(j\) 的位置,\(j\) 左边的元素都比 \(a_j\) 小,\(j\) 右边的元素都比 \(a_j\) 大,划分完成。并且在此次操作中,\(i\) 和 \(j\) 一起扫过了数组中所有元素,所以时间复杂度是 \(O(n)\)。
划分操作完成后,左右两半的数据规模大致变为原来的一半,两部分分别进行递归排序,各自排好序之后,整个数组的顺序就排好了。需要注意的是,和归并排序一样,整个排序过程是递归进行的,所以每次划分需要指定当前要处理的数组区间范围。
// l和r表示区间的起点和终点
int pivot = a[l]; // 选择a[l]作为基准元素
int i = l + 1, j = r; // i指向基准后面的第一个元素,j指向最后一个元素
while (true) {
while (i <= r && a[i] < pivot) i++; // i越过小于基准元素的数
while (j >= l && a[j] > pivot) j--; // j越过大于基准元素的数
if (i >= j) break; // 当i和j错位时,停止
swap(a[i], a[j]); i++; j--; // 交换a[i]和a[j],i向右走,j向左走
}
swap(a[j], a[l]); // 交换a[j]和a[l]
前面的算法分析中没有考虑数组中有相同元素的情况,其实即使有相同的元素,算法的正确性依然没有问题。在 \(i\) 和 \(j\) 两个变量移动的过程中,当遇到和基准元素相等的元素时,停在当前元素上,一会儿依旧做交换。最终划分的效果是,所有元素中,不大于基准元素的会被换到基准元素的左边,不小于基准元素的会被换到基准元素的右边。
在快速排序过程中,每一次划分都是在一个区间范围内,用基准元素作为比对标准,比它小的元素都放在它的左边,比它大的元素都放在它的右边。各自递归左右两部分,直至区间剩下一个元素。划分操作的时间复杂度正比于区间长度,那么递归会有多少层呢?
考虑 \(n\) 个数据进行排序,对于第 \(1\) 层递归调用,\(l=1,r=n\)。不考虑下一层递归,单独看这一层的划分操作,时间复杂度为 \(O(n)\),并且把数组分成了两部分。这两部分不一定是大小相等的,不过如果数据比较随机,可以认为期望大致划分到中间位置附近,下一层递归调用就是分别调用左边的一半和右边的一半。在第 \(2\) 层递归调用中,左边一半的划分需要的计算量和右边一半需要的计算量拼起来正好还是 \(O(n)\)。接下来递归第 \(3\) 层,同理,第 \(3\) 层有 \(4\) 次递归调用,总的计算量还是 \(O(n)\)。以此类推,每层递归调用的计算量都是 \(O(n)\),那么一共有多少层呢?如果每一次递归调用的划分都足够均匀,则层数大约是 \(\log_2^n\) 的,总的时间复杂度是 \(O(n \log n)\)。
如果待排序的数组不是散乱的,而是比较均匀的(极端情况,考虑一个已经排好序的数组),那么快速排序的效率是更高还是更低呢?直觉上,如果一个数组已经排好序了,再拿来排序,应该是不花时间的,或者是速度非常快的。但实际上,考虑对已经排好序的数组进行一次划分,由于基准元素一开始就在最左边,而区间里面没有比基准元素更小的元素了,划分完成以后基准元素还是在最左边。在下次递归调用时,基准元素左边没有元素,基准元素右边全部调用到下一层,下一层的区间长度只少了 \(1\)。如果一开始有 \(n\) 个数,总共需要调用 \(n\) 层,每层划分的时间复杂度是 \(O(n)\),最终时间复杂度反而达到了 \(O(n^2)\)。
一个简单的优化是,不选取第一个数作为基准元素,而是选取当前区间中间的元素。实际编程的时候,可以将当前区间中间的元素与区间第一个元素进行交换,之后的代码就不用改动了。这样即使输入的数据是有序的,依旧可以划分得比较均匀。
例题:P1177 【模板】排序
参考代码
#include <cstdio>
#include <algorithm>
using std::swap;
const int N = 1e5 + 5;
int a[N];
void quicksort(int l, int r) {
int mid = (l + r) / 2;
swap(a[l], a[mid]); // 把中间元素与第一个元素交换
int pivot = a[l]; // 选择a[l]作为基准元素
int i = l + 1, j = r; // i指向基准后面的第一个元素,j指向最后一个元素
while (true) {
while (i <= r && a[i] < pivot) i++; // i越过小于基准元素的数
while (j >= l && a[j] > pivot) j--; // j越过大于基准元素的数
if (i >= j) break; // 当i和j错位时,停止
swap(a[i], a[j]); i++; j--; // 交换a[i]和a[j],i向右走,j向左走
}
swap(a[j], a[l]); // 交换a[j]和a[l]
if (l < j - 1) quicksort(l, j - 1); // 如果左边有不止一个数,递归对左边排序
if (j + 1 < r) quicksort(j + 1, r); // 如果右边有不止一个数,递归对右边排序
}
int main()
{
int n; scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
quicksort(1, n);
for (int i = 1; i <= n; i++) printf("%d ", a[i]);
return 0;
}