矩阵求逆
BZOJ4128
矩阵求逆+BSGS
在原矩阵旁附加单位矩阵同时消元,消元时对所有行消元将原矩阵消成单位矩阵即可
#include <cstdio> #include <iostream> #include <map> #include <cmath> #define LL long long using namespace std; map <LL,int> mpa; map <LL,int> mpb; int mo,n; struct matrix{ int n,m; int a[201][201]; void mul(matrix &b){ int tmp[201][201]; for (int i=0;i<=n;i++) for (int j=0;j<=b.m;j++) tmp[i][j]=0; for (int i=0;i<=n;i++) for (int k=0;k<=m;k++) if (a[i][k]) for (int j=0;j<=b.m;j++) tmp[i][j]+=a[i][k]*b.a[k][j],tmp[i][j]%=mo; for (int i=0;i<=n;i++) for (int j=0;j<=b.m;j++) a[i][j]=tmp[i][j]; } void zero(){ for (int i=1;i<=n;i++) for (int j=1;j<=m;j++) a[i][j]=(i==j); } void cpy(matrix &b){ n=m=b.n; for (int i=1;i<=n;i++) for (int j=1;j<=m;j++) a[i][j]=b.a[i][j]; } }a,b,tmp,bas; int qpow(int bas,int powe){ int ret=1; for (;powe;bas*=bas,bas%=mo){ if (powe&1) ret*=bas,ret%=mo; powe=powe>>1; } return(ret); } LL gethash(matrix &a){ LL ret=0; for (int i=1;i<=n;i++) for (int j=1;j<=n;j++) ret*=233,ret+=a.a[i][j]; return(ret); } void hashin(matrix &a,int t){ LL num=gethash(a); if (mpa[num]) mpb[num]=min(mpb[num],t);else{ mpa[num]=1;mpb[num]=t; } } int query(matrix &a){ static matrix tp; tp.cpy(b); tp.mul(a); LL num=gethash(tp); if (!mpa[num]) return(1e9);else return(mpb[num]); } void inv(matrix &a){ static int tmp[201][201]; for (int i=1;i<=n;i++) for (int j=1;j<=n;j++) tmp[i][j]=a.a[i][j],tmp[i][j+n]=(i==j); for (int i=1;i<=n;i++){ LL t=qpow(tmp[i][i],mo-2); for (int j=i;j<=2*n;j++) tmp[i][j]=tmp[i][j]*t%mo; for (int j=1;j<=n;j++) if (tmp[j][i]&&i!=j){ t=tmp[j][i]; for (int k=i;k<=2*n;k++){ tmp[j][k]-=t*tmp[i][k]; tmp[j][k]%=mo;tmp[j][k]+=mo;tmp[j][k]%=mo; } } } for (int i=1;i<=n;i++) for (int j=1;j<=n;j++) a.a[i][j]=tmp[i][j+n]; } int main(){ scanf("%d%d",&n,&mo); for (int i=1;i<=n;i++) for (int j=1;j<=n;j++) scanf("%d",&a.a[i][j]); a.n=a.m=n; for (int i=1;i<=n;i++) for (int j=1;j<=n;j++) scanf("%d",&b.a[i][j]); b.n=b.m=n; tmp.n=tmp.m=n; int blsiz=sqrt(mo)+1; bas.cpy(a);tmp.zero(); for (int i=1;i<=blsiz;i++) tmp.mul(bas); bas.cpy(tmp); tmp.zero(); hashin(tmp,0); for (int i=1;i<=blsiz;i++){ tmp.mul(bas); hashin(tmp,i*blsiz); } int ans=1e9; inv(a); tmp.zero(); ans=min(ans,query(tmp)); for (int i=1;i<=blsiz;i++){ tmp.mul(a); ans=min(ans,query(tmp)+i); } printf("%d\n",ans); }