Codeforces 1178F DP
题意:有一张白纸条,你需要给这张纸条染色。染色从颜色1开始染色,每次选择纸条的一段染色时,这一段的颜色必须是相同的。现在给你染色后的纸条,问有多少种染色方案?
F1: 思路:最开始的想法是以染色顺序为一个维度,然后染色区间为另外两个维度去DP,但是最后发现不可以,因为之前的所有的染色对后面的影响不确定,只用在染第几种颜色是无法确定现在可以染色的区间的,还无法记忆化,只能去看题解。。。官方题解的DP设计的比较巧妙。我们DP并不关心染色顺序,我们的关注点在区间。我们可以发现,对于一个区间,第一次开始染的颜色一定是标号最小的那种颜色。假设当前区间是[l, r],标号最小的颜色c的位置是p,并且在这个区间内染上颜色c的区间是[a, b],那么区间被分为了4部分:[l, a - 1], [a, p - 1], [p + 1, b], [b + 1, r],递归的求解这四部分,再把答案乘起来就可以了。容易发现a和b的枚举是独立的,那么把枚举的a的答案和枚举的b的答案分别求出,乘起来就可以了。
代码:
#include <bits/stdc++.h> #define INF 0x3f3f3f3f #define pii pair<int, int> #define db double #define LL long long using namespace std; const int maxn = 510; const LL mod = 998244353; LL dp[maxn][maxn]; int n, m; pii c[maxn]; set<int> s, s1; set<int> ::iterator it; int L[maxn], R[maxn]; int cnt[maxn]; int mi[maxn][maxn]; LL dfs(int l, int r) { if(r <= l) return 1; if(dp[l][r] != -1) return dp[l][r]; int p = mi[l][r]; LL sum1 = 0, sum2 = 0; for (int i = l - 1; i < p; i++) { sum1 += (dfs(l, i) * dfs(i + 1, p - 1)) % mod; } sum1 %= mod; for (int j = p; j <= r; j++) { sum2 += (dfs(p + 1, j) * dfs(j + 1, r)) % mod; } sum2 %= mod; dp[l][r] = (sum1 * sum2) % mod; return dp[l][r]; } int main() { scanf("%d%d", &n, &m); memset(dp, -1, sizeof(dp)); for (int i = 1; i <= n; i++) { scanf("%d", &c[i].first); c[i].second = i; L[i] = 1, R[i] = n; } s.insert(0); s.insert(n + 1); sort(c + 1, c + 1 + n); for (int i = 1; i <= n; i++) { L[i] = (*(--s.lower_bound(c[i].second))); R[i] = (*s.lower_bound(c[i].second)); L[i]++, R[i]--; for (int j = L[i]; j <= c[i].second; j++) for (int k = c[i].second; k <= R[i]; k++) mi[j][k] = c[i].second; s.insert(c[i].second); } cout << dfs(1, n) << endl; }
F2: 现在纸条的范围是1e6,不能直接区间DP了。不过通过观察发现,如果纸条上有多个相邻的位置颜色相同,那么可以缩成一个位置,因为这些位置整体要么都不覆盖,要么全部覆盖,和一个位置的贡献相同。容易发现如果方案合法缩点之后点数不会超过1000,所有超过1000的可以直接输出0。缩点之后和F1已经很像了,不过有些小细节需要注意。首先需要判断方案是否合法,容易发现如果有两个颜色相同的位置中间有标号比它们小的颜色,方案不合法。这个直接暴力判断就可以了。其次,在DP计算方案时,除了F1中的四个部分,区间内有可能有多最小标号的位置,这些位置之间的区间也要乘起来。最后,我们不能直接像F1那样枚举区间,比如这种例子:
8 10
8 4 3 5 7 5 2 6 2 1
容易发现,我们在枚举2的覆盖区间时,覆盖区间的左端点不能在2个5之间(包括最右边的5),如果2在这里覆盖了,5这种情况是不可能出现的。这个用指针之类的处理一下就可以了。
代码:
#include <bits/stdc++.h> #define INF 0x3f3f3f3f #define pii pair<int, int> #define db double #define LL long long using namespace std; const int maxn = 1010; const LL mod = 998244353; LL dp[maxn][maxn]; int n, m; int c[1000010]; int L[maxn], R[maxn]; int cnt[maxn], last[maxn], Next[maxn]; int mi[maxn][maxn]; vector<pii> re[maxn]; LL dfs(int l, int r) { if(r <= l) return 1; if(dp[l][r] != -1) return dp[l][r]; int lp = L[mi[l][r]], rp = R[mi[l][r]]; LL sum1 = 0, sum2 = 0, sum3 = 1; int now = mi[l][r]; for (int i = 0; i < re[now].size(); i++) { sum3 = (sum3 * dfs(re[now][i].first, re[now][i].second)) % mod; } for (int i = l - 1; i < lp; i = Next[i]) { sum1 += (dfs(l, i) * dfs(i + 1, lp - 1)) % mod; } sum1 %= mod; for (int j = rp; j <= r; j = Next[j]) { sum2 += (dfs(rp + 1, j) * dfs(j + 1, r)) % mod; } sum2 %= mod; dp[l][r] = (((sum1 * sum2) % mod) * sum3) % mod; return dp[l][r]; } int main() { scanf("%d%d", &n, &m); for (int i = 1; i <= m; i++) { scanf("%d", &c[i]); } int tot = 0; tot = 1; for (int i = 2; i <= m; i++) { if(c[i] == c[i - 1]) { continue; } else { c[++tot] = c[i]; } } if(tot > 1000) { printf("0\n"); return 0; } memset(dp, -1, sizeof(dp)); memset(last, -1, sizeof(last)); for (int i = 1; i <= tot; i++) { mi[i][i] = c[i]; for (int j = i + 1; j <= tot; j++) { mi[i][j] = min(c[j], mi[i][j - 1]); } L[i] = tot + 1; R[i] = 0; } bool flag = 0; for (int i = 1; i <= tot; i++) { L[c[i]] = min(L[c[i]], i); R[c[i]] = max(R[c[i]], i); if(last[c[i]] != -1) { if(mi[last[c[i]] + 1][i - 1] < c[i]) { flag = 1; break; } re[c[i]].push_back(make_pair(last[c[i]] + 1, i - 1)); } last[c[i]] = i; } for (int i = 0; i <= tot; i++) { Next[i] = R[c[i + 1]]; } Next[tot] = tot + 1; if(flag) { printf("0\n"); return 0; } cout << dfs(1, tot) << endl; return 0; }