矩阵快速幂优化DP
矩阵满足:
(1)结合律:
ABC=A(Bc)
(2)分配率
AB+AC=A(B+C)
(3)特殊交换律
单位矩阵(对角线)。
前缀平方求和:
f[n]=n(n+1)*(2n+1)/6
【GT考试】给出长度m的数字串,求长度为n的文本串个数,不出现给出的数字串。(n<=1e9,m<=20)
一看到如此极端的数据范围肯定就是矩阵快速幂了。首先是AC自动机的DP板子,《文本生成器》,求长度固定的串中不出现给出串的串个数。dp[i][j]:表示i指针状态下,长度是j的文本串个数。
for(int i=1;i<=m;++i)
for(int j=0;j<=tot;++j)
for(int k=0;k<=9;++k)
dp[i][ch[j][k]]+=dp[i-1][j])(if(!en[ch[j][k]]))
发现第一维度和第二三维度没有关系,我只需要统计每个\(dp[j]\)被其他\(dp[k]\)累加过多少次就行。我用bas[i][j]=k,代表\(dp[j]=dp[i]*k+...\)。对n次系数快速幂计算,然后用\(ans.w[1][1]=1(dp[0][0]=1)\)乘上bas,第一行的tot+1个结果累加就是答案。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N=10000;
char sr[22];
int ch[30][11],n,m,mod,tot,fail[30],en[30];
deque<int>st;
struct node
{
int w[33][33];
inline void clear()
{
memset(w,0,sizeof(w));
}
}ans,bas;
node operator*(node A,node B)
{
node res;res.clear();
for(int i=1;i<=tot+1;i++)
for(int k=1;k<=tot+1;++k)
{
int rek=A.w[i][k];
for(int j=1;j<=tot+1;++j)
{
res.w[i][j]=(res.w[i][j]+rek*B.w[k][j])%mod;
// printf("(%d %d)%d*(%d %d)%d:res:(%d,%d):%d\n",i,k,A.w[i][k],k,j,B.w[k][j],i,j,res.w[i][j]);
}
}
return res;
}
inline void insert(char*s)
{
int len=strlen(s),now=0;
for(int i=0;i<len;++i)
{
int chr=s[i]-'0';
if(!ch[now][chr])ch[now][chr]=++tot;
now=ch[now][chr];
}
en[now]=1;
}
inline void build()
{
for(int i=0;i<10;++i)
if(ch[0][i])st.push_back(ch[0][i]);
while(!st.empty())
{
int top=st.front();st.pop_front();
for(int i=0;i<10;++i)
{
if(ch[top][i])fail[ch[top][i]]=ch[fail[top]][i],st.push_back(ch[top][i]),en[ch[top][i]]|=en[ch[fail[top]][i]];
else ch[top][i]=ch[fail[top]][i];
}
}
}
inline node qpow(node ar,int p)
{
node cod;cod.clear();int ct=p;
for(int i=1;i<=tot+1;++i)cod.w[i][i]=1;//矩阵初始化对角线
while(ct)
{
if(ct&1)cod=cod*ar;
ct>>=1;
ar=ar*ar;
}
return cod;
}
int main()
{
scanf("%d%d%d",&n,&m,&mod);
scanf("%s",sr);insert(sr);
build();
// printf("%d\n",tot);
for(int i=0;i<=tot;++i)
for(int j=0;j<10;++j)
if(!en[ch[i][j]])
bas.w[i+1][ch[i][j]+1]++;
// for(int i=1;i<=tot+1;++i)
// {
// for(int j=1;j<=tot+1;j++)
// printf("%d ",bas.w[i][j]);
// printf("\n");
// }
// printf("\n\n");
bas=qpow(bas,n);
// for(int i=1;i<=tot+1;++i)
// {
// for(int j=1;j<=tot+1;j++)
// printf("%d ",bas.w[i][j]);
// printf("\n");
// }
int sum=0;
ans.w[1][1]=1;
ans=ans*bas;
for(int i=1;i<=tot+1;i++)
sum=(sum+ans.w[1][i])%mod;
printf("%d",sum);
return 0;
}
/*
4 3 100
111
*/
对于suma[i]和sumb[i]展开优化,发现suma[i]和sumb[i]可以直接从suma[i-1]和sumb[i-1]直接推出,所以在我们O(n)枚举断点的情况下
(前面可以取0,后面不可以),O(1)利用矩阵乘法就可以求出系数和答案。
假设断点是i,状态矩阵[sai,sbi]Bas-->[sa(i+1),sb(i+1)],需要22矩阵系数转移,要注意顺序,就是
由Aibas1=A(i+1),A(i+1)bas2=A(i+2)....那么求矩阵的前缀就应该是fro[i]=fro[i-1]fro[i],我们已经求出的和在前面。
求后缀,就这样想,我已经有了bas(i+1)bas(i+2)bas(i+3)...bas(n)的答案矩阵,我要把bas(i)塞进去,它的实际递推式子应该是
bas(i)bas(i+1)bas(i+2)bas(i+3)...bas(n)=ans(i),所以后缀就是bac[i]=bac[i]bac[i+1]。
当然,如果你是BasA[i]=A[i+1],那么就倒着想,ansi=bas(i)bas(i-1)bas(i-2)...*bas(1),也可以。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define _f(i,a,b) for(register int i=a;i<=b;++i)
#define f_(i,a,b) for(register int i=a;i>=b;--i)
#define chu printf
#define ll long long
#define ull unsigned long long
inline ll re()
{
ll x=0,h=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')h=-1;
ch=getchar();
}
while(ch<='9'&&ch>='0')
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*h;
}
const ll mod=1e9+7,inv2=500000004,inv6=166666668;
// inline ll S2(ll a)
// {
// return a*(a+1)%mod*(2*a+1)%mod*inv6;
// }
int n;
ll a[1000000+100];
struct Node
{
ll a1,a2,b1,b2;
//Node(){a1=a2=b1=b2=0;}
Node operator*(const Node&U)const
{
Node res;
res.a1=(a1*U.a1%mod+a2*U.b1%mod)%mod;
res.a2=(a1*U.a2%mod+a2*U.b2%mod)%mod;
res.b1=(b1*U.a1%mod+b2*U.b1%mod)%mod;
res.b2=(b1*U.a2%mod+b2*U.b2%mod)%mod;
return res;
}
}fro[1000000+10],bac[1000000+10],ans;
int main()
{
freopen("y.in","r",stdin);
freopen("y.out","w",stdout);
n=re();
fro[0]=(Node){1,0,0,1};
bac[n+1]=(Node){1,0,0,1};
_f(i,1,n)a[i]=re();
_f(i,1,n)
fro[i]=(Node){inv2*a[i]%mod*(a[i]+1)%mod,a[i]*(a[i]+1)%mod*(a[i]-1)%mod*inv6%mod,a[i]+1,inv2*a[i]%mod*(a[i]+1)%mod};
_f(i,1,n)
bac[i]=(Node){inv2*a[i]%mod*(a[i]-1)%mod,a[i]*(a[i]+1)%mod*(a[i]-1)%mod*inv6%mod,a[i],inv2*a[i]%mod*(a[i]+1)%mod};
_f(i,1,n)
fro[i]=fro[i-1]*fro[i];
f_(i,n,1)
bac[i]=bac[i]*bac[i+1];
ll sum=0;
_f(i,1,n)
{
ans=bac[i+1]*fro[i-1];
sum=(sum+ans.a1*a[i]%mod+ans.a2)%mod;
}
chu("%lld",sum);
return 0;
}
/*
5
2 2 2 0 2
10
0 1 2 1 2 2 1 0 1 2
3 1 1 1
3 2 1 1
10 1 1 1 1 1 1 1 1 1 1
*/