动态规划
最近学习了动态规划(简称dp),耗光了我的所有脑细胞和头发,特此做一小记纪念一下。
目录(目前已学)
1. 线性dp
2. 区间dp
3. 状压dp
4. 树形dp
5. 数位dp
6. 数据结构的优化
线性dp
背包
讲到线性dp,就不得不讲到它的一个庞大分支,也是一种最广为人知的dp,背包问题。其主要的内容就是假设你有一个容量为\(W\)的背包以及\(n\)个物品,每个物品都有一定的体积和价值,要求求出最多可装价值。(即在物品体积总和小于等于\(W\)时可获得的最大价值,且物品不可分割)。 后来,又延伸出了四种——01背包(每个物品只有一个),完全背包(每个物品有无限个),多重背包(每个物品有一定数量,其实只要拆成多个不同的物品就好了),分组背包(每个物品有它的对应组别,每个组有限定的可取数量)。
Code:
01背包
#include<iostream>
#include<cstdio>
#include<cstring>
#define MAXN 1000
#define MAXM 10000
using namespace std;
int read() {
int f = 1, sum = 0;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
sum = sum * 10 + ch - '0';
ch = getchar();
}
return f * sum;
}
struct GOODS {
int weight, value;
};
GOODS goods[MAXN + 9];
int f[MAXM + 9];
signed main() {
int n = read(), W = read();
for(int i = 1; i <= n; ++i) {
goods[i].weight = read(), goods[i].value = read();
}
f[0] = 0;
for(int i = 1; i <= n; ++i) {
for(int j = W; j >= goods[i].weight; --j) {
f[j] = max(f[j], f[j - goods[i].weight] + goods[i].value);
}
}
int ans = -0x3f3f3f3f;
for(int i = 1; i <= W; ++i) {
if(f[i] > ans) ans = f[i];
}
printf("%d", ans);
return 0;
}
多重背包
#include<iostream>
#include<cstdio>
#include<cstring>
#define MAXN 1000
#define MAXM 10000
using namespace std;
int read() {
int f = 1, sum = 0;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
sum = sum * 10 + ch - '0';
ch = getchar();
}
return f * sum;
}
struct GOODS {
int weight, value, c;
};
GOODS goods[MAXN + 9];
int f[MAXM + 9];
signed main() {
int n = read(), W = read();
for(int i = 1; i <= n; ++i) {
goods[i].weight = read(), goods[i].value = read(), goods[i].c = read(); //每种物品有c个
}
f[0] = 0;
for(int i = 1; i <= n; ++i) {
for(int k = 1; k <= goods[i].c; ++k) { //将物品拆成c个
for(int j = W; j >= goods[i].weight; --j) {
f[j] = max(f[j], f[j - goods[i].weight] + goods[i].value);
}
}
}
int ans = -0x3f3f3f3f;
for(int i = 1; i <= W; ++i) {
if(f[i] > ans) ans = f[i];
}
printf("%d", ans);
return 0;
}
完全背包
#include<iostream>
#include<cstdio>
#include<cstring>
#define MAXN 1000
#define MAXM 10000
using namespace std;
int read() {
int f = 1, sum = 0;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
sum = sum * 10 + ch - '0';
ch = getchar();
}
return f * sum;
}
struct GOODS {
int weight, value;
};
GOODS goods[MAXN + 9];
int f[MAXM + 9];
signed main() {
int n = read(), W = read();
for(int i = 1; i <= n; ++i) {
goods[i].weight = read(), goods[i].value = read();
}
f[0] = 0;
for(int i = 1; i <= n; ++i) {
for(int j = goods[i].weight; j <= W; --j) { //01背包是反着更新,完全背包正着更新
f[j] = max(f[j], f[j - goods[i].weight] + goods[i].value);
}
}
int ans = -0x3f3f3f3f;
for(int i = 1; i <= W; ++i) {
if(f[i] > ans) ans = f[i];
}
printf("%d", ans);
return 0;
}
分组背包
#include<cstdio> //每组选一个为例子
#include<algorithm>
#include<cstring>
#define MAXN 1000
#define MAXM 1000
#define MAXK 100
using namespace std;
int read() {
int f = 1, sum = 0;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
sum = sum * 10 + ch - '0';
ch = getchar();
}
return f * sum;
}
struct GOODS {
int weight, value;
};
GOODS goods[MAXK + 9][MAXN + 9];
int sum[MAXK + 9];
struct IN {
int a, b, c;
};
IN in[MAXN + 9];
int f[MAXM + 9];
bool cmp(IN x, IN y) {
return x.c < y.c;
}
signed main() {
int n = read(), m = read();
for(int i = 1; i <= n; ++i) {
in[i].a = read();
in[i].b = read();
in[i].c = read();
}
sort(in + 1, in + n + 1, cmp);
int K = 0;
for(int i = 1; i <= n; ++i) {
if(in[i].c != in[i - 1].c) ++K;
sum[K]++;
goods[K][sum[K]].weight = in[i].a, goods[K][sum[K]].value = in[i].b;
}
memset(f, -0x3f, sizeof(f));
f[0] = 0;
for(int k = 1; k <= K; ++k) {
for(int j = m; j >= 0; --j) {
for(int i = 1; i <= sum[k]; ++i) {
if(j >= goods[k][i].weight) f[j] = max(f[j], f[j - goods[k][i].weight] + goods[k][i].value);
}
}
}
int ans = -0x3f3f3f3f;
for(int i = 1; i <= m; ++i) {
if(f[i] > ans) ans = f[i];
}
printf("%d", ans);
return 0;
}
附——背包九讲
背包九讲——例题+解析
普通线形dp
一般的线性dp就是用前面的值来更新后面的值。当我们遇到的题目看似可以用贪心,但是因为有后效性所以被PASS掉;可以用暴力搜索但又会T掉,这时候就可以用dp了。dp类的题目一般没有别的办法,只有多练。
尼克的任务
#include<iostream>
#include<algorithm>
#define MAXN 1000000
using namespace std;
int read() {
int f = 1, sum = 0;
char ch = getchar();
while(!isdigit(ch)) {
if(ch == '-') f = -1;
ch = getchar();
}
while(isdigit(ch)) {
sum = sum * 10 + ch - '0';
ch = getchar();
}
return f * sum;
}
struct st {
int strat, time;
} a[MAXN + 9];
bool cmp(st x, st y) {
return x.strat < y.strat;
}
int f[MAXN + 9];
int main() {
int n = read(), k = read();
for(int i = 1; i <= k; ++i) {
a[i].strat = read(), a[i].time = read();
}
sort(a + 1, a + k + 1, cmp);
int l = k;
f[n + 1] = 0;
for(int i = n; i >= 1; --i) {
if(a[l].strat == i) {
int t2 = l, t1;
while(a[l - 1].strat == a[l].strat) --l;
t1 = l;
int maxi = 0;
for(int j = t1; j <= t2; ++j) {
if(f[i + a[j].time] > maxi) maxi = f[i + a[j].time];
}
f[i] = maxi;
--l;
} else f[i] = f[i + 1] + 1;
}
printf("%d", f[1]);
return 0;
}
[USACO08MAR]River Crossing S
#include<cstdio>
#include<cstring>
#include<algorithm>
#define MAXN 2500
#define int long long
using namespace std;
int read() {
int f = 1, sum = 0;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
sum = sum * 10 + ch - '0';
ch = getchar();
}
return f * sum;
}
int sum[MAXN + 9], f[MAXN + 9];
signed main() {
int n = read();
sum[0] = read();
for(int i = 1; i <= n; ++i) {
sum[i] = read();
sum[i] += sum[i - 1];
}
memset(f, 0x3f, sizeof(f));
f[0] = 0;
for(int i = 1; i <= n; ++i) {
for(int k = n; k >= 1; --k) {
f[i] = min(f[i], f[i - k] + sum[k] + sum[0]);
}
}
printf("%lld", f[n] - sum[0]);
return 0;
}
膜拜
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define MAXN 2500
using namespace std;
int read() {
int f = 1, sum = 0;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
sum = sum * 10 + ch - '0';
ch = getchar();
}
return f * sum;
}
int f[MAXN + 9];
int sum1[MAXN + 9], sum2[MAXN + 9];
signed main() {
int n = read(), m = read();
for(int i = 1; i <= n; ++i) {
int a = read();
if(a == 1) sum1[i]++;
if(a == 2) sum2[i]++;
sum1[i] += sum1[i - 1];
sum2[i] += sum2[i - 1];
}
memset(f, 0x3f, sizeof(f));
f[0] = 0;
for(int i = 1; i <= n; ++i) {
for(int j = 1; j <= i; ++j) {
if(abs((sum1[i] - sum1[j - 1]) - (sum2[i] - sum2[j - 1])) <= m || sum1[i] - sum1[j - 1] == 0 || sum2[i] - sum2[j - 1] == 0) {
f[i] = min(f[i], f[j - 1] + 1);
}
}
}
printf("%d", f[n]);
return 0;
}
[USACO04NOV]Apple Catching G
#include<cstdio>
#include<cstring>
#include<algorithm>
#define MAXN 1000
#define MAXM 30
using namespace std;
int read() {
int f = 1, sum = 0;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
sum = sum * 10 + ch - '0';
ch = getchar();
}
return f * sum;
}
int f[MAXN + 9][MAXM], a[MAXN + 9];
signed main() {
int t = read(), w = read();
for(int i = 1; i <= t; ++i) {
a[i] = read();
}
memset(f, -0x3f, sizeof(f));
if(a[1] == 1) {
f[1][0] = 1;
f[1][1] = 0;
}
else {
f[1][0] = 0;
f[1][1] = 1;
}
for(int i = 2; i <= t; ++i) {
for(int j = 0; j <= w; ++j) {
if(j == 0) f[i][j] = max(f[i][j], f[i - 1][j]);
f[i][j] = max(f[i][j], max(f[i - 1][j], f[i - 1][j - 1]));
int tmp = j % 2 == 1 ? 2 : 1;
if(a[i] == tmp) f[i][j]++;
}
}
int ans = 0;
for(int i = 0; i <= w; ++i) {
ans = max(ans, f[t][i]);
}
printf("%d", ans);
return 0;
}
区间dp
区间dp一般有很强的模板性,即绝大多数区间dp的题目都可以用同一个模板打过去:
for(int len = 1; len <= n; ++len) //枚举区间长度,从小到大
for(int i = 1, j = i + len - 1; j <= n; ++i, ++j) //枚举左右端点
区间dp的主要思维就是用小区间更新到大区间。比如要求\(i-j\)区间的最小更改次数,那么\(f[i][j] = min(f[i][k] + f[k + 1][j])(i <= k < j)\),\(f[i][i]\)可以预处理。
例题
关路灯
#include<iostream>
#include<cstring>
#define MAXN 50
using namespace std;
int read() {
int f = 1, sum = 0;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
sum = sum * 10 + ch - '0';
ch = getchar();
}
return f * sum;
}
int l[MAXN + 9][MAXN + 9], r[MAXN + 9][MAXN + 9], use[MAXN + 9], sub[MAXN + 9];
int main() {
memset(l, 0x3f, sizeof(l));
memset(r, 0x3f, sizeof(r));
int n = read(), c = read();
for(int i = 1; i <= n; ++i) {
sub[i] = read(), use[i] = read();
use[i] += use[i - 1];
}
l[c][c] = r[c][c] = 0;
for(int len = 2; len <= n; ++len) {
for(int i = 1, j = i + len - 1; j <= n; ++i, ++j) {
l[i][j] = min(l[i + 1][j] + (sub[i + 1] - sub[i]) * (use[n] - (use[j] - use[i])), r[i + 1][j] + (sub[j] - sub[i]) * (use[n] - (use[j] - use[i])));
r[i][j] = min(l[i][j - 1] + (sub[j] - sub[i]) * (use[n] - (use[j - 1] - use[i - 1])), r[i][j - 1] + (sub[j] - sub[j - 1]) * (use[n] - (use[j - 1] - use[i - 1])));
}
}
int ans = min(l[1][n], r[1][n]);
printf("%d", ans);
return 0;
}
状压dp
状压, 全称状态压缩, 即把一个很难表示的状态压缩成一维(或多维)数字来进行转移. 通常, 状态压缩使用的都是位运算, 所以最好先掌握位运算.
例题分析(洛谷 P1896 [SCOI2005] 互不侵犯)
在这道题中, 由于它的\(n\)很小, 只有\(8\), 所以每一行我们都可以压缩成一个数字. 具体地, 先看下图:
我们用\(1\)表示有国王或能够被国王攻击的格子, \(0\)表示安全的格子. 那么我们会发现每一行都形成了一个\(01\)串. 由于每一行最多只有\(8\)格, 所以我们把这个\(01\)串当做二进制下的一个数, 转化成十进制后并不会很大, 可以当做数组的下标. 那么我们就可以定义\(f[i][j][k]\)表示到第\(i\)行且当前行状态为\(j\), 已经放置了\(k\)个国王时的可行性. 到这里就差不多了, 可以直接切掉了.
Code:
#include<bits/stdc++.h>
using namespace std;
int ok[109],cnt,king[109];
long long f[10][109][109];
int main() {
int n,K;
scanf("%d%d",&n,&K);
for(int i=0; i<1<<n; ++i) {
if((i&i<<1)==0) {
ok[++cnt]=i;
int t=i;
while(t) king[cnt]++,t-=t&-t;
}
}
f[0][1][0]=1;
for(int i=1; i<=n; ++i) {
for(int pre=1; pre<=cnt; ++pre) {
for(int now=1; now<=cnt; ++now) {
int s2=ok[pre],s1=ok[now];
if(!((s2|s2<<1|s2>>1)&s1))
for(int k=king[now]; k<=K; ++k)
f[i][now][k]+=f[i-1][pre][k-king[now]];
}
}
}
long long ans=0;
for(int i=1; i<=cnt; ++i) {
ans+=f[n][i][K];
}
printf("%lld",ans);
return 0;
}
洛谷 P1879 [USACO06NOV]Corn Fields G
#include<bits/stdc++.h>
using namespace std;
int n,m,cnt;
int c[400],mp[20];
long long f[20][400];
void dfs(int now,int sum) {
if(now>=m) {
c[++cnt]=sum;
return;
}
dfs(now+1,sum);
dfs(now+2,sum+(1<<now));
}
int main() {
scanf("%d%d",&n,&m);
dfs(0,0);
for(int i=1; i<=n; i++) {
for(int j=m; j>=1; j--) {
int x;
scanf("%d",&x);
mp[i]+=x*(1<<(j-1));
}
}
for(int i=1; i<=cnt; i++)
if((mp[1]&c[i])==c[i]) f[1][i]=1;
for(int i=2; i<=n; i++) {
for(int k=1; k<=cnt; k++) {
if((mp[i]&c[k])!=c[k]) continue;
for(int j=1; j<=cnt; j++) {
if((mp[i-1]&c[j])!=c[j]||(c[j]&c[k])!=0) continue;
f[i][k]=(f[i][k]+f[i-1][j])%100000000;
}
}
}
long long ans=0;
for(int i=1; i<=cnt; i++)
if((mp[n]&c[i])==c[i]) ans=(ans+f[n][i])%100000000;
printf("%lld\n",ans);
return 0;
}
洛谷 P2051 [AHOI2009] 中国象棋
#include<bits/stdc++.h>
#define int long long
const int mod=9999973;
using namespace std;
int n,m,ans;
int f[109][109][109];
inline int C(int x) {
return (x*(x-1))>>1;
}
main() {
scanf("%lld%lld",&n,&m);
f[0][0][0]=1;
for(int i=1; i<=n; i++) {
for(int j=0; j<=m; j++) {
for(int k=0; k<=m-j; k++) {
f[i][j][k]=f[i-1][j][k];
if(k>=1)(f[i][j][k]+=f[i-1][j+1][k-1]*(j+1));
if(j>=1)(f[i][j][k]+=f[i-1][j-1][k]*(m-j-k+1));
if(k>=2)(f[i][j][k]+=f[i-1][j+2][k-2]*(((j+2)*(j+1))/2));
if(k>=1)(f[i][j][k]+=f[i-1][j][k-1]*j*(m-j-k+1));
if(j>=2)(f[i][j][k]+=f[i-1][j-2][k]*C(m-j-k+2));
f[i][j][k]%=mod;
}
}
}
for(int i=0; i<=m; i++)
for(int j=0; j<=m; j++)
ans=(ans+f[n][i][j])%mod;
printf("%lld",(ans+mod)%mod);
return 0;
}
洛谷 P2704 [NOI2001] 炮兵阵地
#include<bits/stdc++.h>
using namespace std;
int ok[66], cnt, sum[66], mp[101];
int f[101][66][66], ans;
int main() {
int n, m;
scanf ("%d%d", &n, &m);
for (int i = 0; i < (1 << m); i++)
if (!(i & i << 1) && !(i & i << 2)) {
ok[++cnt] = i;
int t = i;
while (t) {
sum[cnt]++;
t -= t & -t;
}
}
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++) {
char ch;
scanf (" %c", &ch);
if (ch == 'H')
mp[i] |= (1 << (m - j));
}
for (int j = 1; j <= cnt; j++)
if ((ok[j] & mp[1]) == 0)
f[1][j][0] = sum[j];
for (int i = 1; i <= cnt; i++)
if ((ok[i] & mp[2]) == 0)
for (int j = 1; j <= cnt; j++)
if ((ok[j] & mp[1]) == 0 && (ok[i] & ok[j]) == 0)
f[2][i][j] = sum[j] + sum[i];
for (int i = 3; i <= n; i++)
for (int j = 1; j <= cnt; j++)
if ((ok[j] & mp[i]) == 0)
for (int k = 1; k <= cnt; k++)
if ((ok[k] & mp[i - 1]) == 0 && (ok[k] & ok[j]) == 0)
for (int l = 1; l <= cnt; l++)
if ((ok[l] & mp[i - 2]) == 0 && (ok[l] & ok[j]) == 0 && (ok[l] & ok[k]) == 0) {
f[i][j][k] = max(f[i][j][k], f[i - 1][k][l] + sum[j]);
ans = max(ans, f[i][j][k]);
}
printf ("%d", ans);
return 0;
}