矩阵快速幂优化DP

矩阵满足:
(1)结合律:
ABC=A(Bc)
(2)分配率
AB+AC=A(B+C)
(3)特殊交换律
单位矩阵(对角线)。
前缀平方求和:
f[n]=n
(n+1)*(2n+1)/6

【GT考试】给出长度m的数字串,求长度为n的文本串个数,不出现给出的数字串。(n<=1e9,m<=20)

一看到如此极端的数据范围肯定就是矩阵快速幂了。首先是AC自动机的DP板子,《文本生成器》,求长度固定的串中不出现给出串的串个数。dp[i][j]:表示i指针状态下,长度是j的文本串个数。

for(int i=1;i<=m;++i)
for(int j=0;j<=tot;++j)
for(int k=0;k<=9;++k)
dp[i][ch[j][k]]+=dp[i-1][j])(if(!en[ch[j][k]]))

发现第一维度和第二三维度没有关系,我只需要统计每个\(dp[j]\)被其他\(dp[k]\)累加过多少次就行。我用bas[i][j]=k,代表\(dp[j]=dp[i]*k+...\)。对n次系数快速幂计算,然后用\(ans.w[1][1]=1(dp[0][0]=1)\)乘上bas,第一行的tot+1个结果累加就是答案。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N=10000;
char sr[22];
int ch[30][11],n,m,mod,tot,fail[30],en[30];
deque<int>st;
struct node
{
	int w[33][33];
	inline void clear()
	{
		memset(w,0,sizeof(w));
	}
}ans,bas;
node operator*(node A,node B)
{
	node res;res.clear();
	for(int i=1;i<=tot+1;i++)
	for(int k=1;k<=tot+1;++k)
	{
		int rek=A.w[i][k];
	for(int j=1;j<=tot+1;++j)
	{
		res.w[i][j]=(res.w[i][j]+rek*B.w[k][j])%mod;
	//	printf("(%d %d)%d*(%d %d)%d:res:(%d,%d):%d\n",i,k,A.w[i][k],k,j,B.w[k][j],i,j,res.w[i][j]);
	}
	}

	return res;	
}
inline void insert(char*s)
{
	int len=strlen(s),now=0;
	for(int i=0;i<len;++i)
	{
		int chr=s[i]-'0';
		if(!ch[now][chr])ch[now][chr]=++tot;
		now=ch[now][chr];
	}
	en[now]=1;
}
inline void build()
{
	for(int i=0;i<10;++i)
	if(ch[0][i])st.push_back(ch[0][i]);
	while(!st.empty())
	{
		int top=st.front();st.pop_front();
		for(int i=0;i<10;++i)
		{
			if(ch[top][i])fail[ch[top][i]]=ch[fail[top]][i],st.push_back(ch[top][i]),en[ch[top][i]]|=en[ch[fail[top]][i]];
			else ch[top][i]=ch[fail[top]][i];
		}
	}
}
inline node qpow(node ar,int p)
{
	node cod;cod.clear();int ct=p;
	for(int i=1;i<=tot+1;++i)cod.w[i][i]=1;//矩阵初始化对角线 
	while(ct)
	{
		if(ct&1)cod=cod*ar;
		ct>>=1;

		ar=ar*ar;
	
	}
	return cod;
}
int main()
{
	scanf("%d%d%d",&n,&m,&mod);
	scanf("%s",sr);insert(sr);
	build();
//	printf("%d\n",tot);
	for(int i=0;i<=tot;++i)
	for(int j=0;j<10;++j)
	if(!en[ch[i][j]])
	bas.w[i+1][ch[i][j]+1]++;
//	for(int i=1;i<=tot+1;++i)
//	{
//		for(int j=1;j<=tot+1;j++)
//		printf("%d ",bas.w[i][j]);		
//		printf("\n");
//	}
//	printf("\n\n");
	bas=qpow(bas,n);
//	for(int i=1;i<=tot+1;++i)
//	{
//		for(int j=1;j<=tot+1;j++)
//		printf("%d ",bas.w[i][j]);		
//		printf("\n");
//	}
	int sum=0;
	ans.w[1][1]=1;
	ans=ans*bas;

	for(int i=1;i<=tot+1;i++)
	sum=(sum+ans.w[1][i])%mod;
	printf("%d",sum);
	return 0;
}
/*
4 3 100 
111
*/
>###需要非常注意转移顺序 ###![image](https://img2022.cnblogs.com/blog/2761046/202209/2761046-20220915122046441-1234132826.png)

image

对于suma[i]和sumb[i]展开优化,发现suma[i]和sumb[i]可以直接从suma[i-1]和sumb[i-1]直接推出,所以在我们O(n)枚举断点的情况下
(前面可以取0,后面不可以),O(1)利用矩阵乘法就可以求出系数和答案。
假设断点是i,状态矩阵[sai,sbi]Bas-->[sa(i+1),sb(i+1)],需要22矩阵系数转移,要注意顺序,就是
由Aibas1=A(i+1),A(i+1)bas2=A(i+2)....那么求矩阵的前缀就应该是fro[i]=fro[i-1]fro[i],我们已经求出的和在前面。
求后缀,就这样想,我已经有了bas(i+1)
bas(i+2)bas(i+3)...bas(n)的答案矩阵,我要把bas(i)塞进去,它的实际递推式子应该是
bas(i)bas(i+1)bas(i+2)bas(i+3)...bas(n)=ans(i),所以后缀就是bac[i]=bac[i]bac[i+1]。
当然,如果你是Bas
A[i]=A[i+1],那么就倒着想,ansi=bas(i)bas(i-1)bas(i-2)...*bas(1),也可以。

点击查看代码










#include<bits/stdc++.h>
using namespace std;
#define _f(i,a,b) for(register int i=a;i<=b;++i)
#define f_(i,a,b) for(register int i=a;i>=b;--i)
#define chu printf
#define ll long long
#define ull unsigned long long
inline ll re()
{
    ll x=0,h=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-')h=-1;
        ch=getchar();
    }
    while(ch<='9'&&ch>='0')
    {
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    return x*h;
}
const ll mod=1e9+7,inv2=500000004,inv6=166666668;
// inline ll S2(ll a)
// {
//     return a*(a+1)%mod*(2*a+1)%mod*inv6;
// }
int n;
ll a[1000000+100];
struct Node
{
    ll a1,a2,b1,b2;
    //Node(){a1=a2=b1=b2=0;}
    Node operator*(const Node&U)const
    {
        Node res;
        res.a1=(a1*U.a1%mod+a2*U.b1%mod)%mod;
        res.a2=(a1*U.a2%mod+a2*U.b2%mod)%mod;
        res.b1=(b1*U.a1%mod+b2*U.b1%mod)%mod;
        res.b2=(b1*U.a2%mod+b2*U.b2%mod)%mod;
        return res;
    }
}fro[1000000+10],bac[1000000+10],ans;
int main()
{
 freopen("y.in","r",stdin);
 freopen("y.out","w",stdout);
    n=re();
    fro[0]=(Node){1,0,0,1};
    bac[n+1]=(Node){1,0,0,1};
    _f(i,1,n)a[i]=re();
    _f(i,1,n)
    fro[i]=(Node){inv2*a[i]%mod*(a[i]+1)%mod,a[i]*(a[i]+1)%mod*(a[i]-1)%mod*inv6%mod,a[i]+1,inv2*a[i]%mod*(a[i]+1)%mod};
    _f(i,1,n)
    bac[i]=(Node){inv2*a[i]%mod*(a[i]-1)%mod,a[i]*(a[i]+1)%mod*(a[i]-1)%mod*inv6%mod,a[i],inv2*a[i]%mod*(a[i]+1)%mod};
    _f(i,1,n)
    fro[i]=fro[i-1]*fro[i];
    f_(i,n,1)
    bac[i]=bac[i]*bac[i+1];
    ll sum=0;
    _f(i,1,n)
    {
        ans=bac[i+1]*fro[i-1];
        sum=(sum+ans.a1*a[i]%mod+ans.a2)%mod;
    }
    chu("%lld",sum);
    return 0;
}
/*
5
2 2 2 0 2 
10
0 1 2 1 2 2 1 0 1 2 
3 1 1 1
3 2 1 1 
10 1 1 1 1 1 1 1 1 1 1 
*/

posted on 2022-08-25 14:59  HZOI-曹蓉  阅读(82)  评论(0编辑  收藏  举报