Nowcoder Removal ( 字符串上的线性 DP )
题意 : 给出长度为 n 的字符串、问你准确删除 m 个元素之后、能产生多少种不同的子串
分析 ( 参考博客 ):
可以考虑线性 DP 解决这个问题
试着如下定义动态规划数组
dp[i][j] = 在加入第 i 个字符串后、总共删除了 j 个字符后的不同子串的个数
不难写出状态转移方程 dp[i][j] = dp[i-1][j] + dp[i-1][j-1]
代表在 i 这个字符加入之后、在删除总次数为 j 的情况下是否删除 i 的两个状态转移而来
但是这样子势必会有重复的字串出现、例如 aaa 这个字符串 删除第一个和第二个都产生 aa
所以需要减去重复的部分、注意到如果在第 i 个字符之前后一个和它相同的字符出现记做 pre_i
而且满足 (i字符位置) - (pre_i字符位置) <= j 则说明会有重复计算的情况
因为如果在总共删除的 j 个字符中包含了pre_i ~ i 之间的字符、那么便会产生重复
例如 abcexxe 发现有一前一后两个 e 两个之间相隔有两个 x
那么如果用上述的转移方程去更新的话 dp[4][0] = "abce" 且 dp[7][3] = 包含了 "abce"
所以会有诸如这样子的重复、减去就行了
怎么减? if (i字符位置) - (pre_i字符位置) <= j 执行 dp[i][j] -= dp[ pre_i - 1 ][ j - ( i - pre_i ) ]
因为先前重复计算的部分就包含了 pre_i 位置下减去了 j - ( i - pre_i ) 个字符的部分
#include<bits/stdc++.h> #define LL __int64 #define ULL unsigned long long #define scl(i) scanf("%lld", &i) #define scll(i, j) scanf("%lld %lld", &i, &j) #define sclll(i, j, k) scanf("%lld %lld %lld", &i, &j, &k) #define scllll(i, j, k, l) scanf("%lld %lld %lld %lld", &i, &j, &k, &l) #define scs(i) scanf("%s", i) #define sci(i) scanf("%d", &i) #define scd(i) scanf("%lf", &i) #define scIl(i) scanf("%I64d", &i) #define scii(i, j) scanf("%d %d", &i, &j) #define scdd(i, j) scanf("%lf %lf", &i, &j) #define scIll(i, j) scanf("%I64d %I64d", &i, &j) #define sciii(i, j, k) scanf("%d %d %d", &i, &j, &k) #define scddd(i, j, k) scanf("%lf %lf %lf", &i, &j, &k) #define scIlll(i, j, k) scanf("%I64d %I64d %I64d", &i, &j, &k) #define sciiii(i, j, k, l) scanf("%d %d %d %d", &i, &j, &k, &l) #define scdddd(i, j, k, l) scanf("%lf %lf %lf %lf", &i, &j, &k, &l) #define scIllll(i, j, k, l) scanf("%I64d %I64d %I64d %I64d", &i, &j, &k, &l) #define lson l, m, rt<<1 #define rson m+1, r, rt<<1|1 #define lowbit(i) (i & (-i)) #define mem(i, j) memset(i, j, sizeof(i)) #define fir first #define sec second #define VI vector<int> #define ins(i) insert(i) #define pb(i) push_back(i) #define pii pair<int, int> #define VL vector<long long> #define mk(i, j) make_pair(i, j) #define all(i) i.begin(), i.end() #define pll pair<long long, long long> #define _TIME 0 #define _INPUT 0 #define _OUTPUT 0 clock_t START, END; void __stTIME(); void __enTIME(); void __IOPUT(); using namespace std; const int maxn = 1e5 + 10; const int mod = 1e9 + 7; int arr[maxn]; int Pre[maxn]; int Last[10 + 5]; int dp[maxn][10 + 5]; int main(void){__stTIME();__IOPUT(); int n, m, k; while(~sciii(n, m, k)){ for(int i=0; i<=n; i++) Last[i] = 0; for(int i=1; i<=n; i++){ sci(arr[i]); Pre[i] = Last[arr[i]]; Last[arr[i]] = i; } for(int i=0; i<=m; i++) dp[i][i] = 1; ///代表空串、在下面的转移方程中 dp[i-1][j] 会用到 ///其意义是更新到 i 为止在不删除第 i 个字符情况下总共删除了 j 个 ///字符的情况、此时就说明只剩下第 i 个字符、dp数值应该为 1 ///故给空串赋值为 1 for(int i=1; i<=n; i++){ dp[i][0] = 1; for(int j=1; j<=min(m, i-1); j++){ dp[i][j] = (dp[i-1][j-1]%mod + dp[i-1][j]%mod) % mod; if(Pre[i] && i - Pre[i] <= j){ dp[i][j] = (dp[i][j] - dp[ Pre[i] - 1 ][ j - (i - Pre[i]) ] + mod ) % mod; } } } printf("%d\n", dp[n][m]%mod); } __enTIME();return 0;} void __stTIME() { #if _TIME START = clock(); #endif } void __enTIME() { #if _TIME END = clock(); cerr<<"execute time = "<<(double)(END-START)/CLOCKS_PER_SEC<<endl; #endif } void __IOPUT() { #if _INPUT freopen("in.txt", "r", stdin); #endif #if _OUTPUT freopen("out.txt", "w", stdout); #endif }