[日常训练]string
Description
给定一个长度为\(n\)的字符串,串中的字符保证是前\(k\)个小写字母。你可以在字符串后再添加\(m\)个字符,使得新字符串所包含的不同的子序列数量尽量多。当然,前提是只能添加前\(k\)个小写字母。求新的长度为\(n+m\)的串最多的不同子序列数量。答案对\(10^9+7\)取模。
Input
输入第一行两个数\(m,k\)。
接下来一行一个字符串,长度为\(n\),表示原始的字符串\(s\)。
Output
一个数,表示答案。
Sample Input
1 3
ac
Sample Output
8
HINT
\(n,m\;\leq\;10^6,k\;\leq\;26\)
Solution
当\(m=0\)时,
\(lst[i]\)表示字符\(i\)上一次出现的位置,
\(f[i]\)表示以第\(i\)位结尾的新出现的不同的子序列的个数.
以第\(x(lst[i]\;\leq\;x<i)\)位结尾的新出现的子序列末尾加上\(s[i]\)为一个新的子序列.(反证法可证\(x(0<x<lst[i])\)不可行)
\(f[i]=\sum_{j=lst[s[i]]}^{i-1}f[j]\).
这个可以用前缀和优化.
当\(m\not=0\)时,
设\(sum[i]=\sum_{j=1}^{i}f[j]\),
则\(f[i]=sum[i-1]-sum[lst[j]-1](n<i\;\leq\;n+m)\)
\(f[i]\)最大,即\(lst[j]-1\)最小.
#include<cmath>
#include<ctime>
#include<queue>
#include<stack>
#include<cstdio>
#include<vector>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define K 2000005
#define M 1000000007
using namespace std;
int s[K],sum,tmp,m,n,t;
bool u[K];char c;
int f[K],lst[K];
inline void Aireen(){
scanf("%d%d",&m,&t);
c=getchar();
for(n=1;scanf("%c",&c)==1;++n){
if(!(c>='a'&&c<='z'))
break;
if(lst[c-'a'])
f[n]=(s[n-1]-s[lst[c-'a']-1]+M)%M;
else f[n]=(s[n-1]+1)%M;
s[n]=(s[n-1]+f[n])%M;
lst[c-'a']=n;
}
--n;
if(t) for(int i=n+1,j,k;i<=n+m;++i){
k=lst[0];j=0;
for(int l=1;l<t;++l){
if(lst[l]<k){
k=lst[l];j=l;
}
}
printf("j=%d\n",j);
if(lst[j])
f[i]=(s[i-1]-s[lst[j]-1]+M)%M;
else f[i]=(s[i-1]+1)%M;
s[i]=(s[i-1]+f[i])%M;
lst[j]=i;
}
printf("%d\n",(s[n+m]+1)%M);
}
int main(){
freopen("string.in","r",stdin);
freopen("string.out","w",stdout);
Aireen();
fclose(stdin);
fclose(stdout);
return 0;
}
因为卡空间\(1MB\),每次转移只与\(f[lst[i]-1]\)有关,所以只需\(O(k)\)的空间.
#include<cmath>
#include<ctime>
#include<queue>
#include<stack>
#include<cstdio>
#include<vector>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define K 26
#define M 1000000007
using namespace std;
int s[K],lst[K],sum,tmp,m,n,t;
char c;
inline void Aireen(){
scanf("%d%d",&m,&t);
c=getchar();
for(n=1;scanf("%c",&c)==1;++n){
if(!(c>='a'&&c<='z'))
break;
tmp=s[c-'a'];s[c-'a']=sum;
if(lst[c-'a']) sum=((sum<<1)%M-tmp+M)%M;
else sum=((sum<<1)+1)%M;
lst[c-'a']=n;
}
--n;
if(t) for(int i=1,j,k;i<=m;++i){
k=lst[0];j=0;
for(int l=1;l<t;++l){
if(lst[l]<k){
k=lst[l];j=l;
}
}
tmp=s[j];s[j]=sum;
if(lst[j]) sum=((sum<<1)%M-tmp+M)%M;
else sum=((sum<<1)+1)%M;
lst[j]=i+m;
}
printf("%d\n",(sum+1)%M);
}
int main(){
freopen("string.in","r",stdin);
freopen("string.out","w",stdout);
Aireen();
fclose(stdin);
fclose(stdout);
return 0;
}