【LOJ6074】【2017 山东一轮集训 Day6】子序列 DP
题目描述
有一个由前 \(m\) 个小写字母组成的串 \(S\),有 \(q\) 个询问,每次给你 \(l,r\),问你 \(S_{l\ldots r}\) 有多少个非空子序列。
\(m=9,n=\lvert S\rvert \leq {10}^5,q\leq {10}^5\)
题解
题解接下来的部分求得答案是包含空串的答案。
最简单的做法是DP。
设 \(f_{i,j}\) 为前 \(i\) 个字符,末尾为 \(j\) 的子序列个数。
特殊的,如果 \(j=m+1\) 就说明当前还没有选任何字符。
递推式为
答案为 \(\sum_{i=1}^{m+1}f_{n,i}\)
复杂度为 \(O(nq)\)。
注意到这个转移是一个矩阵乘法的形式,记
那么
这里 \(A_i\) 就是把单位矩阵的第 \(S_i\) 行全部设为一得到的矩阵。
记
那么我们最终的答案就是
我们可以维护一个 \(A\) 的前缀积和 \(A\) 的逆的前缀积即可。需要注意矩阵乘法的顺序。
时间复杂度:\(O(nm^3+qm^2)\)
可以发现,做矩阵乘法的时候只有一行有变化,那么预处理的矩阵乘法的复杂度可以降到 \(O(nm^2)\)。总复杂度为 \(O((n+q)m^2)\)
其实这道题还能进一步优化。
对于询问,我们只需要在预处理的时候把矩阵乘以 \(U\) 或 \(V\) 的结果保存下来,就可以做到 \(O(qm)\)。
对于预处理,求 \(A_rA_{r-1}\cdots A_1\) 的时候左乘转移矩阵的时候实际上是把 \(S_i\) 这一行中每个位置的值改为这一列所有元素的和,直接维护一下每列的和就好了。右乘转移矩阵的逆矩阵就是对于每一行,除了 \(S_i\) 这一列外其他所有列都减掉这个位置。那么可以维护一下这一列所有元素共同减掉的数,然后修改一下这个位置单点的值。
具体来说,假设原来的矩阵的某一行 \(0\) 是长这样:
对第三个位置操作后就会变成
这样预处理的复杂度就降到了 \(O(nm)\)
总复杂度为 \(O((n+q)m)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<functional>
#include<cmath>
#include<vector>
//using namespace std;
using std::min;
using std::max;
using std::swap;
using std::sort;
using std::reverse;
using std::random_shuffle;
using std::lower_bound;
using std::upper_bound;
using std::unique;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef std::pair<int,int> pii;
typedef std::pair<ll,ll> pll;
void open(const char *s){
#ifndef ONLINE_JUDGE
char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;}
void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');}
int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;}
int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;}
const int p=1000000007;
const int N=100010;
const int M=10;
ll fp(ll a,ll b)
{
ll s=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)
s=s*a%p;
return s;
}
const ll inv2=fp(2,p-2);
int plus(int a,int b)
{
a+=b;
return a>=p?a-p:a;
}
int plus2(int a)
{
a+=a;
return a>=p?a-p:a;
}
int minus(int a,int b)
{
a-=b;
return a<0?a+p:a;
}
int c[N];
int n,q;
int f1[N][10];
int a1[10][10];
int f2[N][10];
int a2[10][10];
char str[N];
void init()
{
for(int i=0;i<=9;i++)
a1[i][i]=a2[i][i]=f1[0][i]=1;
for(int i=1;i<=n;i++)
{
int v=str[i]-'a';
for(int j=0;j<=9;j++)
{
f1[i][j]=minus(plus2(f1[i-1][j]),a1[v][j]);
a1[v][j]=f1[i-1][j];
f2[i][j]=a2[v][j];
a2[v][j]=minus(plus2(a2[v][j]),f2[i-1][j]);
}
}
}
int main()
{
scanf("%s",str+1);
n=strlen(str+1);
scanf("%d",&q);
init();
int l,r;
for(int i=1;i<=q;i++)
{
l=rd();
r=rd();
int ans=f1[r][9]-1;
for(int j=0;j<=8;j++)
ans=(ans-(ll)f1[r][j]*f2[l-1][j])%p;
ans=plus(ans,p);
printf("%d\n",ans);
}
return 0;
}