PTA L3-020 至多删三个字符 (序列dp/序列自动机)
给定一个全部由小写英文字母组成的字符串,允许你至多删掉其中 3 个字符,结果可能有多少种不同的字符串?
输入格式:
输入在一行中给出全部由小写英文字母组成的、长度在区间 [4, 1] 内的字符串。
输出格式:
在一行中输出至多删掉其中 3 个字符后不同字符串的个数。
输入样例:
ababcc
输出样例:
25
提示:
删掉 0 个字符得到 "ababcc"。
删掉 1 个字符得到 "babcc", "aabcc", "abbcc", "abacc" 和 "ababc"。
删掉 2 个字符得到 "abcc", "bbcc", "bacc", "babc", "aacc", "aabc", "abbc", "abac" 和 "abab"。
删掉 3 个字符得到 "abc", "bcc", "acc", "bbc", "bac", "bab", "aac", "aab", "abb" 和 "aba"。
解法:
前置技能:求一个序列中所有的不同子序列个数。
eg:FZU - 2129
设dp[i]为序列a的前i个元素所组成的不同子序列个数,则有状态转移方程:$dp[i]=\left\{\begin{matrix}\begin{aligned}&2dp[i-1]+1,pre[a[i]]=-1\\&2dp[i-1]-dp[pre[a[i]]-1],pre[a[i]]\neq -1\end{aligned}\end{matrix}\right.$
其中pre[a[i]]表示a[i]前面第一个和a[i]相同的元素的下标。
解释:第i个元素a[i]有两种选择:选或不选。
若不选a[i],则dp[i]继承dp[i-1]的全部子序列,因此有dp[i]+=dp[i-1]。
若选a[i],则dp[i]在dp[i-1]的全部子序列的尾部填加了个元素a[i],因此仍有dp[i]+=dp[i-1]。但这样会有很多重复的序列,因此要去重,即去掉前面和a[i]相同的元素之前的序列(因为它们加上a[i]形成的序列已经被算过了),因此有dp[i]-=dp[pre[a[i]]-1]。特别地,如果a[i]前面没有与a[i]相同的元素,那么没有重复的序列,并且a[i]自己单独形成一个新序列,此时dp[i]++。
1 #include<cstdio> 2 #include<cstring> 3 using namespace std; 4 typedef long long ll; 5 typedef double db; 6 const int N=1e6+10,mod=1e9+7; 7 int a[N],n,dp[N],pre[N]; 8 int main() { 9 while(scanf("%d",&n)==1) { 10 memset(pre,-1,sizeof pre); 11 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 12 dp[0]=0; 13 for(int i=1; i<=n; ++i) { 14 dp[i]=(ll)dp[i-1]*2%mod; 15 if(~pre[a[i]])dp[i]=((ll)dp[i]-dp[pre[a[i]]-1])%mod; 16 else dp[i]=(dp[i]+1)%mod; 17 pre[a[i]]=i; 18 } 19 printf("%d\n",(dp[n]+mod)%mod); 20 } 21 return 0; 22 }
回到正题,此题是上题的升级版,等价于求一个长度为n的序列中长度为n,n-1,n-2,n-3的不同子序列个数之和。
基本思路是一致的,只需要在上述代码的基础上稍作改动即可。
设dp[i][j]为前i个元素删了j个元素所形成的子序列个数,则有$dp[i][j]=\left\{\begin{matrix}\begin{aligned}&dp[i-1][j-1]+dp[i-1][j],pre[a[i]]=-1,j\neq i-1\\&dp[i-1][j-1]+dp[i-1][j]+1,pre[a[i]]=-1,j=i-1\\&dp[i-1][j-1]+dp[i-1][j]-dp[pre[a[i]]-1][j-(i-pre[a[i]])],pre[a[i]]\neq -1\end{aligned}\end{matrix}\right.$
推导过程类似,注意j的变化即可。
1 #include<cstdio> 2 #include<cstring> 3 using namespace std; 4 typedef long long ll; 5 typedef double db; 6 const int N=1e6+10; 7 char a[N]; 8 int n,pre[300]; 9 ll dp[N][4]; 10 int main() { 11 memset(pre,-1,sizeof pre); 12 scanf("%s",a+1),n=strlen(a+1); 13 for(int i=1; i<=n; ++i) { 14 for(int j=0; j<=3; ++j) { 15 if(j>0)dp[i][j]+=dp[i-1][j-1]; 16 dp[i][j]+=dp[i-1][j]; 17 if(~pre[a[i]]&&j>=i-pre[a[i]])dp[i][j]-=dp[pre[a[i]]-1][j-(i-pre[a[i]])]; 18 else if(i==j+1)dp[i][j]++; 19 } 20 pre[a[i]]=i; 21 } 22 printf("%lld\n",dp[n][0]+dp[n][1]+dp[n][2]+dp[n][3]); 23 return 0; 24 }
还有另一种解法是利用序列自动机,很简单,设go[i][j]为第i个元素后第一个元素j出现的位置,先用类似dp的方式建立自动机,则问题转化成了一个DAG上的dp问题。
但是由于序列自动机空间消耗较大,直接dfs可能会爆内存,比如这样:
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 typedef double db; 5 const int N=1e6+10,M=26; 6 char s[N]; 7 int n,go[N][M],dp[N][4]; 8 void build() { 9 memset(go[n],0,sizeof go[n]); 10 for(int i=n-1; i>=0; --i)memcpy(go[i],go[i+1],sizeof go[i]),go[i][s[i]-'a']=i+1; 11 } 12 int dfs(int u,int k) { 13 if(k>3)return 0; 14 int& ret=dp[u][k]; 15 if(~ret)return ret; 16 ret=(k+(n-u)<=3); 17 for(int i=0; i<M; ++i)if(go[u][i])ret+=dfs(go[u][i],k+go[u][i]-u-1); 18 return ret; 19 } 20 int main() { 21 scanf("%s",s),n=strlen(s); 22 build(); 23 memset(dp,-1,sizeof dp); 24 printf("%d\n",dfs(0,0)); 25 return 0; 26 }
解决方法是自底而上,一遍dp一遍更新go数组,成功AC:
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 typedef double db; 5 const int N=1e6+10,M=26; 6 char s[N]; 7 int n,go[M]; 8 ll dp[N][4]; 9 int main() { 10 scanf("%s",s),n=strlen(s); 11 dp[n][0]=dp[n][1]=dp[n][2]=dp[n][3]=1; 12 for(int i=n-1; i>=0; --i) { 13 go[s[i]-'a']=i+1; 14 for(int j=0; j<=3; ++j) { 15 dp[i][j]=(j+(n-i)<=3); 16 for(int k=0; k<M; ++k)if(go[k]&&j+go[k]-i-1<=3)dp[i][j]+=dp[go[k]][j+go[k]-i-1]; 17 } 18 } 19 printf("%lld\n",dp[0][0]); 20 return 0; 21 }
虽然序列自动机的功能比较强大,但时间和空间的消耗都与元素集合的大小有关,因此当元素集合过大的时候,可能就并不吃香了~~