Loading

2023 6月 dp做题记录

6月 dp做题记录

P5664 [CSP-S2019] Emiya 家今天的饭

分析条件,我们要选出来的菜的集合需要满足的限制,集合不为空和烹饪方法互不相同都好处理,这样保证每种烹饪方法是独立不受影响的,并且至多选一种,所以每种烹饪方法 \(i\) 选菜的方案为 \(sum_i=\sum\limits_{j=1}^m a_{i,j}\),总方案就为 \(\sum\limits_{i=1}^n sum_i-1\),减一为集合为空的情况。

在第三种限制里,集合中每种食材的使用次数不超过 \(\left\lfloor\dfrac{k}{2}\right\rfloor\) 次,若是直接顺着计算,肯定不好求,因为不超过的方案对比超过太复杂了。正难则反,考虑容斥,我们前面求出了不考虑第三种限制的方案数,只要我们求出了不符合第三种限制的方案数,两个相减即可。

考虑用动态规划,我们在考虑不符合第三种限制是,同时也要满足前两种限制,这样保证求出来的方案一定在总方案数中。最朴素的,设状态 \(dp_{i,j,k}\) 为前 \(i\) 种烹饪方法中选了 \(j\) 道菜,其中 \(k\) 道菜是第 \(g\) 种食材做的。这里需要枚举 \(i\)\(j\)\(k\)\(g\) 四个量,转移很好想

\(dp_{i,j,k}=dp_{i-1,j,k}+dp_{i-1,j-1,k}\times(sum_i-a_{i,g})+dp_{i-1,j-1,k-1}\times a_{i,g}\)

分为不选第 \(i\) 种烹饪方法,选了但不是第 \(g\) 种,选了是第 \(g\) 种。先枚举 \(g\),每次累加,这样的不合法方案数为 \(\sum\limits_{k}dp_{n,j,k}\),这里的 \(k>\left\lfloor\dfrac{j}{2}\right\rfloor\) 。这样的复杂度是 \(O(n^3m)\),通过不了此题。

但我们再思考一下,在一种不合法的方案中不合法的食材有且仅有一种,因为假设有两种,一定超过总的选菜数量,即\(2\times(\left\lfloor\dfrac{j}{2}\right\rfloor+1)>j\)。所以等价于 \(k>\dfrac{j}{2}\),化简得到 \(k-(j-k)>0\),感性理解就是当前不合法食材数减去合法食材数大于 \(0\),这样不满足第三种限制的方案只需要不满足这个限制就行了。(这里也可以考虑感性理解,超过一半的食材肯定不会有两个,并且一次只会有一种不合法食材,就有了如果不合法的食材比合法食材还多,那么就是不满足限制的)

放在状态里面就是不关心 \(j\)\(k\),只关心不合法与合法之间的差值,这样状态就可以简化成 \(dp_{i,j}\) 表示前 \(i\) 种烹饪方法,差值为 \(j\) 的方案数,这里的 \(j\) 可能是负数,所以加一个 \(n\) 来保证是正的(差值最大为 \(n\))。状态转移为:

\(dp_{i,j}=dp_{i-1,j}+dp_{i-1,j-1}\times a_{i,g}+dp_{i-1,j+1}\times(sum_i-a_{i,g})\)

每次枚举一个 \(g\),计算当前情况下的不合法数,即 \(\sum\limits_{j>0}dp_{n,j}\),不同 \(g\) 之间不会重合,所以每处理一次就让总方案数减去它即可。

#include<bits/stdc++.h>
using namespace std;
int n,m,a[120][2020];
long long s[120],dp[120][220],ans=1;
const int mod=998244353; 
int main(){
	cin >> n >> m;
	for(int i = 1; i <= n; i++){
		for(int j = 1; j <= m; j++){
			cin >> a[i][j];
			s[i] += a[i][j];
			s[i] %= mod;
		}
	}
	for(int i = 1; i <= n; i++){
		ans *= (s[i] + 1);
		ans %= mod;
	}
	ans = (ans - 1 + mod) % mod;
	for(int k = 1; k <= m; k++){
		memset(dp, 0, sizeof(dp));
		dp[0][100] = 1;
		for(int i = 1; i <= n; i++){
			long long now = s[i] - a[i][k];
			for(int j = 0; j <= n + 100; j++){
				dp[i][j] = dp[i - 1][j];
				if(j) dp[i][j] += dp[i - 1][j - 1] * a[i][k];
				dp[i][j] += dp[i - 1][j + 1] * now;
				dp[i][j] %= mod;
			}
		}
		for(int i = 101; i <= n + 100; i++){
			ans -= dp[n][i];
			ans = (ans + mod) % mod;
		}
	}
	cout << ans << endl;
	return 0;
}

P8867 [NOIP2022] 建造军营

计算合法的建造军营和看守道路方案数,合法即为去掉一条没人看守的边后军营之间依然连通,因为是一条,所以容易发现在图中,强连通分量的边被割去一条是一定不会影响军营连通的,即强连通分量的边想看守就看守,不作为决定性因素。只有割边与方案的合法性有关。

所以我们考虑缩点,在无向图缩点后,原图会变成一个树,这点方便我们做树形 dp。

每个强连通分量内的方案数是可以预处理的,处理出点数为 \(v_i\),边数为 \(e_i\)。那么在这个强连通分量中,不选军营的方案数是 \(2^{e_i}\),选至少一个军营的方案数为 \(2^{v_i+e_i}-2^{e_i}\)

考虑题目,经过上面分析,题目简化成,在一颗树上选出若干个点,选出的点在去掉一条边后依然连通的方案数。意思即选出的点之间相连的唯一路径上的边一定要看守,其他随意。问题缩小到子树上,限制只在子树里有军营,由于枚举子树时,唯一不同的就是根节点,它是我们区分不同方案的关键,所以我们限制当前的子树的根节点一定为相连路径上的一点,这点方便转移,因为这样子树中的军营就可以通过根节点相连。

在限制下,考虑 \(u\) 节点和它的儿子 \(v\),它们之间相不相连取决于 \(v\) 的子树中有无军营。顺着这个可以设出状态 \(dp_{u,0/1}\) 表示在以 \(u\) 节点为根的子树中的没有/有军营的方案数。根据枚举儿子节点顺序会有前 \(i\) 个儿子的隐藏状态,类似背包,分为当前子节点选不选节点,转移可以写成:

\(\begin{cases}dp_{u,1}=dp_{u,0}\times dp_{v,1}+dp_{u,1}\times(2\times dp_{v,0}+dp_{v,1})\\dp_{u,0}=dp_{u,0}\times (2\times dp_{v,0})\end{cases}\)

\(2\) 的地方是因为这里的边 \((u,v)\) 由于 \(v\) 中没有军营可以选和不选。考虑了子树内的方案数,并且前面为了统计答案,限制了军营只在子树内,所以子树外的边可以随便选。这里要注意的就是在统计答案上如何保证不重不漏,计算出 \(u\) 子树的方案后,我们回到了 \(fa_u\) 子树,为了不重复,\((fa_u,u)\) 这条边就会选入,这是和 \(u\) 子树方案根本的区别,所以在计入子树 \(u\) 的答案时,先预处理出 \(sz_i\) 表示 \(i\) 子树中的边数(包括强连通分量的边),原本为 \(dp_{u,1}\times 2^{sz_1-sz_u}\),但这其中会多计算一次和 \(fa_u\) 相连的方案,需要改成 \(dp_{u,1}\times 2^{sz_1-sz_u-1}\) 保证不重。

答案即为:

\(\begin{cases}ans\leftarrow dp_{u,1}&u=1\\ans\leftarrow dp_{u,1}\times 2^{sz_1-sz_u-1}&u\ne 1\end{cases}\)

要理解不漏也很容易,我们将每个节点作为中转点,实际上所有的选点方案都一定至少会在一棵子树上被统计。

在这题中,我们经过了缩点,将题目转化为树形 dp,统计方案数,在树上选点可以依照题意给出可以转移的状态,并且为了统计方案,可以给状态一些隐藏的限制,一是便于转移,二是可以使状态的意义更加明晰,特指某一种情况下的方案,使得小的方案之间没有并集,便于统计答案。最后的复杂度为 \(O(n+m)\)

#include<bits/stdc++.h>
using namespace std;
int read(){
	int x = 0, f = 1;
	char c = getchar();
	while(c < '0' || c > '9'){
		if(c == '-') f = -1;
		c = getchar();
	}
	while(c >= '0' && c <= '9'){
		x = (x << 1) + (x << 3) + (c - '0');
		c = getchar();
	}
	return x * f;
}
const int mod = 1000000007;
int n, m, cnt, cnt2, top, idx, tot;
int h[500010], h2[500010];
long long g[500010], sz[500010], sum1[500010];
int low[500010], dfn[500010], bel[500010], ins[500010], st[500010];
long long dp[500010][2], ans;
struct node{
	int to, nxt;
}e[2000010];
struct node2{
	int to, nxt;
}e2[2000010];
void add(int u, int v){
	e[++cnt].to = v;
	e[cnt].nxt = h[u];
	h[u] = cnt;
}
void add2(int u, int v){
	e2[++cnt2].to = v;
	e2[cnt2].nxt = h2[u];
	h2[u] = cnt2;
}
void tarjan(int u, int fa){
	dfn[u] = low[u] = ++tot;
	st[++top] = u;
	ins[u] = 1;
	for(int i = h[u]; i; i = e[i].nxt){
		int v = e[i].to;
		if(!ins[v]){
			tarjan(v, u);
			low[u] = min(low[u], low[v]);
		}
		else if(v != fa){
			low[u] = min(low[u], dfn[v]);
		}
	}
	if(low[u] == dfn[u]){
		++idx;
		int v;
		do{
			v = st[top--];
			bel[v] = idx;
			sum1[idx]++;
			ins[v] = 0;
		}while(v != u);
	}
}
long long ksm(long long a, long long b){
	long long ans = 1;
	while(b){
		if(b & 1) ans = (ans * a) % mod;
		a = (a * a) % mod;
		b >>= 1;
	}
	return ans;
}
void init(int u, int fa){
	sz[u] = g[u];
	for(int i = h2[u]; i; i = e2[i].nxt){
		int v = e2[i].to;
		if(v == fa) continue;
		init(v, u);
		sz[u] += sz[v] + 1;
	}
}
void dfs(int u, int fa){
	dp[u][0] = ksm(2, g[u]) % mod, dp[u][1] = (ksm(2, sum1[u] + g[u]) - dp[u][0] + mod) % mod;
	for(int i = h2[u]; i; i = e2[i].nxt){
		int v = e2[i].to;
		if(v == fa) continue;
		dfs(v, u);
		dp[u][1] = (dp[u][1] * (2 * dp[v][0] % mod + dp[v][1]) % mod + dp[u][0] * dp[v][1] % mod) % mod;
		dp[u][0] = dp[u][0] * (2 * dp[v][0] % mod) % mod;
	}
	if(u == 1) ans += dp[u][1], ans %= mod;
	else ans += (dp[u][1] * ksm(2, sz[1] - sz[u] - 1) % mod) % mod, ans %= mod;
}
int main(){
	n = read(), m = read();
	for(int i = 1; i <= m; i++){
		int u = read(), v = read();
		add(u, v), add(v, u);
	}
	tarjan(1, 0);
	for(int i = 1; i <= n; i++){
		for(int j = h[i]; j; j = e[j].nxt){
			int v = e[j].to;
			if(bel[i] == bel[v]) g[bel[i]]++;
			else add2(bel[i], bel[v]);
		}
	}
	for(int i = 1; i <= idx; i++) g[i] /= 2;
	init(1, 0);
	dfs(1, 0);
	cout << ans << endl;
	return 0;
}

[ARC115E] LEQ and NEQ

如果不考虑第二个条件的话,那么答案显然是 \(\sum a_i\),所以我们考虑容斥掉不合法的方案。

容斥的基本条件,我们要找到一个共性,也就是能够容斥的性质。这一题中,容斥的对象就是不符合第二个条件的两项。所以我们可以设 \(g_i\) 为刚好有 \(i\) 组违反条件的项,\(f_i\) 为至少有 \(i\) 组违反条件的项。我们直接套用容斥的公式。

\[ans=\sum\limits_{i=0}^{n-1}(-1)^i\ f_i \]

处理 \(f_i\) 的过程需要用到动态规划。我们发现不同的数字代表一段区间,彼此之间相邻的违反条件的点也正好对应一段区间,坏点越多序列段数越少,坏点的增多会导致段数的减少,所以反过来,我们可以用段数的多少来满足”至少“这个条件(比如 \(5\) 个分成 \(3\) 段,说明至少有 \(2\) 个坏点),且分段问题更好解决。于是我们设状态 \(dp_{i,j}\) 为前 \(i\) 个正好分了 \(j\) 段的方案数。可能会觉得这不就和 \(f_i\)至少冲突了吗?我们这里的分段并不是严格意义上的分段,我们只规定了段内一定相同,而相邻的虽然不是一段但也可以相同。

因为我们转移需要一整段\(a_i\) 的大小关系,又因为不是严格分段,不需要考虑段之间是否一定不同,所以转移为

\(dp_{i,j}=\sum\limits_{k=0}^{i-1}dp_{k,j-1}\times \min\limits_{o=k+1}^ia_o\)

统计答案也就变成

\(ans=\sum\limits_{i=0}^{n-1}(-1)^i\ dp_{n,n-i}\)

发现 \(j\) 位置只需要 \(j\)\(j-1\),即与当前奇偶性有关,在奇偶之间转移,所以可以用滚动数组降维。降维之后统计答案也方便,因为统计答案时同样只需要关心奇偶性。

\(dp_{i,0/1}=\sum\limits_{k=0}^{i-1}dp_{k,1/0}\times \min\limits_{o=k+1}^ia_o\)

这样转移是 \(O(n^2)\) 的,需要优化转移。每枚举一个 \(i\),就要多考虑一个 \(a_i\),并且对于连续区间的最小值,它的转移的贡献也是连续的,往前枚举 \(j\) 的过程中,会有一个时刻,\(a_i\) 永远不会是之后的最小值,这个时候即为左边第一个小于 \(a_i\) 的数。以这个时刻为分隔点 \(k\)\(k\) 及它右边的贡献乘的都是 \(a_i\),而 \([k,i-1]\)\(dp\) 值可以通过前缀和统计;左边的贡献由于已经不受影响新枚举的 \(a_i\),可以发现总贡献之前已经处理过了,即为 \(dp_{k,0/1}\)

关于找到左边第一个小于 \(a_i\) 的位置,可以用单调栈实现。复杂度就降到 \(O(n)\)

#include <bits/stdc++.h>
using namespace std;
long long read(){
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c - '0');
        c = getchar();
    }
    return x * f;
}
long long n, ans, mod = 998244353;
long long a[500010], s[500010][2], dp[500010][2];
long long st[500010], top;
int main(){
    n = read();
    for(int i = 1; i <= n; i++){
        a[i] = read();
    }
    dp[0][0] = s[0][0] = 1;
    for(int i = 1; i <= n; i++){
        while(top > 0 && a[st[top]] >= a[i]) top--;
        if(!top){
            for(int j = 0; j <= 1; j++) dp[i][j] = 1ll * (dp[i][j] + s[i - 1][j ^ 1] * a[i]) % mod;
        }
        else{
            for(int j = 0; j <= 1; j++) dp[i][j] = 1ll * (dp[st[top]][j] + (s[i - 1][j ^ 1] - s[st[top] - 1][j ^ 1] + mod) * a[i]) % mod;
        }
        s[i][0] = (s[i - 1][0] + dp[i][0]) % mod;
        s[i][1] = (s[i - 1][1] + dp[i][1]) % mod;
        st[++top] = i;
    }
    if(n % 2 == 1) ans = (dp[n][1] - dp[n][0] + mod) % mod;
    else ans = (dp[n][0] - dp[n][1] + mod) % mod;
    cout << ans;
    return 0;
}

P3800 Power收集

这题的状态很明确,因为完全可以把每一个网格看成状态,网格之间相互转移,所以设状态 \(dp_{i,j}\) 为走到第 \(i\) 行第 \(j\) 列时的最大值。依照题意,一层的转移只与上一层有关,所以转移为

\(dp_{i,j}=\max\limits_{j-t\le k\le j+t}(dp_{i-1,k})+a_{i,j}\)

典型的单调队列形式,我们需要的只有一段区间中的最大值,并且区间移动是连续的。瓶颈在于我们枚举 \(j\) 的时候,我们只能知道 \([j-t,j]\) 的最大值。解决方法很简单,最大值有结合律,即 \(\max(a,b,c)=\max(a,(b,c))\),所以我们正着和反着都做一遍,把 \([j-t,j]\)\([j,j+t]\) 的最大值分别求出来,跑两遍单调队列即可。

其他还有可以优化的地方,比如由于一层的转移只与上一层有关,所以可以滚掉 \(i\) 这一维。

P3594 [POI2015] WIL

这题中,可以一次将任意长度小于等于 \(d\) 的区间变为 \(0\),求修改完之后区间和小于 \(p\) 的最长区间长度。

对于区间和,我们可以用前缀和 \(sum_x=\sum\limits_{i=1}^xa_i\),来 \(O(1)\) 求出。

可以容易想到,“任意长度小于等于 \(d\) 的区间” 在实际操作中一定是贪心地取长度为 \(d\) 的区间,因为多取一定不劣。所以一个暴力做法是,我们枚举区间的左右断点 \(l\)\(r\),再枚举一个 在 \([i,j]\) 之间的 \(k\) 为修改的区间右端点,判断减去一段区间后的区间和是否小于 \(q\) 来更新答案。复杂度 \(O(n^3)\)

考虑优化,对于一个左端点,我们一定是找它满足条件的最远右端点;同样,对于一个右端点,我们一定是找它满足条件的最远左端点。这个性质可以用上双指针,只需要枚举 \(r\)\(k\)\(l\) 只需要根据区间和单调向右移动即可。复杂度 \(O(n^2)\)

现在的瓶颈是,因为我们判断一个区间能否满足条件,一定要找到它的最大修改区间才能一次做出决定,所以能否优化掉寻找最大修改区间的时间呢?可以用到单调队列,每枚举一个 \(i\),就多一个区间 \([i-d,i]\),所以我们维护当前满足条件的最大修改区间,判断时直接取出即可。这里的条件指当前的 \(l\) 指针是否已经超过当前队首修改区间的左端点

如果此时的 \([l,r]\) 用上最大修改区间还是大于 \(p\) 的话,那么 \(l\) 只能往右走,并同时删去由于 \(l\) 向右走而导致不合法的修改区间,保证下一次取出的最大修改区间是合法的,由于单调性,我们不用担心区间会不会被多删。

由于一个数最多进出队列一次,并且 \(l\) 单调向右移动,所以单调队列和双指针都是线性的,复杂度降到 \(O(n)\)

#include <bits/stdc++.h>
using namespace std;
int read(){
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c - '0');
        c = getchar();
    }
    return x * f;
}
int n, m, k, t, ans;
int dp[4010][4010], a[4010][4010], q[4010];
int main(){
    n = read(), m = read(), k = read(), t = read();
    for(int i = 1; i <= k; i++){
        int x = read(), y = read(), v = read();
        a[x][y] = v;
    }
    for(int i = 1; i <= n; i++){
        int head = 1, tail = 0;
        for(int j = 1; j <= m; j++){
            while(head <= tail && dp[i - 1][q[tail]] <= dp[i - 1][j]) tail--;
            while(head <= tail && q[head] + t < j) head++;
            q[++tail] = j;
            dp[i][j] = max(dp[i][j], dp[i - 1][q[head]] + a[i][j]); 
        }
        head = 1, tail = 0;
        for(int j = m; j >= 1; j--){
            while(head <= tail && dp[i - 1][q[tail]] <= dp[i - 1][j]) tail--;
            while(head <= tail && q[head] - t > j) head++;
            q[++tail] = j;
            dp[i][j] = max(dp[i][j], dp[i - 1][q[head]] + a[i][j]); 
        }
    }
    for(int i = 1; i <= m; i++) ans = max(ans, dp[n][i]);
    cout << ans << endl;
    return 0;
}
posted @ 2024-04-20 11:36  Fire_Raku  阅读(5)  评论(0编辑  收藏  举报