矩阵乘法与邻接矩阵
问题
设现在有一个邻接矩阵 \(A\) ,那么 \(A^p\) 表示什么。
相信大家都搜了一下,发现 \(A[i][j]\) 表示从 \(i\) 到 \(j\) 经过了 \(p\) 步的总方案数,但是原理却不一定明白,所以这篇文章主要想从另一个思路证明上面的结论。
证明
我们先抛开矩阵乘法,考虑单纯用 \(DP\) 这道题目怎样做。
我们可以设 \(DP[i][j][p]\) 表示从 \(i\) 点出发,经过了 \(p\) 步到达 \(j\) 点的方案数。
初始化因为 \(p=0\) 的初始化没有什么意义,所以我们直接看 \(p=1\) 的初始化。
当 \(p=1\) 的时候的初始化其实就是这张图的邻接矩阵,因为邻接矩阵就是原图,原图中间的两个点就是通过一条边相连。
我们考虑转移
\[DP[i][j][p]=\sum_{k=1}^n DP[k][j][p-1]*DP[i][k][1]
\]
其实就是我们考虑枚举全部的点在 \(p-1\) 的方案数,然后乘 ,
因为第三维只是和 \(p-1\) 有关,所以第三维可以省掉,然后和朴素的矩阵幂比较,我们发现他们长的一模一样,其实都是在求邻接矩阵的 \(p\) 次方,所以我们就可以使用矩阵快速幂来直接求最终的矩阵了,上面的问题也得到了证明。
典型例题
P3758 [TJOI2017]可乐
这道题目这篇博客已经讲得很明白了,不想再抄一遍了。emmm
#include<bits/stdc++.h>
#define ll long long
struct Mat{
int size;
ll **M=NULL;
inline ll Start()
{
if (M!=NULL) return M[1][1];
else return LLONG_MIN;
}
inline void Clear(int sz)
{
if (M==NULL) {New(sz);return ;}
for (int i=0;i<=sz;i++)
for (int j=0;j<=sz;j++)
M[i][j]=0;
return ;
}
inline void New(int sz)
{
if (M!=NULL)
{
printf("\nRE\n");
printf("This matrix has been used!\n");
return ;
}
size=sz;
M=new ll*[sz+10];
for (int i=0;i<sz+10;i++)
M[i]=new ll[sz+10];
Clear(sz);
return ;
}
inline void Build(int sz)
{
size=sz;
if (M==NULL) New(sz);
for (int i=1;i<=sz;i++) M[i][i]=1;
return ;
}
inline void Init(ll now[],int sz)
{
if (M==NULL) New(sz);
int num=0;
for (int i=1;i<=sz;i++)
for (int j=1;j<=sz;j++)
M[i][j]=now[++num];
return ;
}
inline void Out()
{
if (M!=NULL)
for (int i=1;i<=size;i++)
{
for (int j=1;j<=size;j++)
printf("%lld ",M[i][j]);
printf("\n");
}
return ;
}
inline void Delete()
{
for (int i=0;i<size+10;i++)
delete []M[i];
delete []M;
M=NULL;
size=0;
return ;
}
};
ll mod;
inline Mat operator * (Mat a,Mat b)
{
Mat c;
c.Clear(a.size);
for (int i=1;i<=c.size;i++)
for (int j=1;j<=c.size;j++)
for (int k=1;k<=c.size;k++)
c.M[i][j]=(c.M[i][j]+a.M[i][k]*b.M[k][j]%mod)%mod;
return c;
}
inline Mat Mat_qpow(Mat a,ll p)
{
Mat ans,base;
base.New(a.size);
ans.Build(a.size);
for (base=a;p;p>>=1,base=base*base)
if (p&1) ans=ans*base;
return ans;
}
int main()
{
mod=2017;
int n,m,k;
scanf("%d%d",&n,&m);
Mat now,ans;
now.New(n+1);ans.New(n+1);
for (int i=1;i<=m;i++)
{
int s,e;
scanf("%d%d",&s,&e);
now.M[s][e]=now.M[e][s]=1;
}
scanf("%d",&k);
for (int i=1;i<=n;i++)
{
now.M[i][n+1]=1;
now.M[i][i]=1;
}
now.M[n+1][n+1]=1;
ans=Mat_qpow(now,k);
ll Ans=0;
for (int i=1;i<=n+1;i++)
Ans=(Ans+ans.M[1][i])%mod;
printf("%lld\n",Ans);
ans.Delete();now.Delete();
return 0;
}
上面的证明有一些问题,以后再补锅