矩阵求逆

BZOJ4128

矩阵求逆+BSGS

在原矩阵旁附加单位矩阵同时消元,消元时对所有行消元将原矩阵消成单位矩阵即可

#include <cstdio>
#include <iostream>
#include <map>
#include <cmath>
#define LL long long 
using namespace std;

    map <LL,int> mpa;
    map <LL,int> mpb;

    int mo,n;

    struct matrix{
    int n,m;
    int a[201][201];
      
    void mul(matrix &b){
      int tmp[201][201];
      for (int i=0;i<=n;i++) 
        for (int j=0;j<=b.m;j++) 
          tmp[i][j]=0;
        
      for (int i=0;i<=n;i++)
        for (int k=0;k<=m;k++)
          if (a[i][k]) 
            for (int j=0;j<=b.m;j++)
              tmp[i][j]+=a[i][k]*b.a[k][j],tmp[i][j]%=mo;
        
      for (int i=0;i<=n;i++)
        for (int j=0;j<=b.m;j++)
          a[i][j]=tmp[i][j];
    }
    
    void zero(){
      for (int i=1;i<=n;i++) for (int j=1;j<=m;j++) a[i][j]=(i==j);
    }
    
    void cpy(matrix &b){
      n=m=b.n;
      for (int i=1;i<=n;i++) for (int j=1;j<=m;j++) a[i][j]=b.a[i][j];
    }
  }a,b,tmp,bas;

  int qpow(int bas,int powe){
      int ret=1;
      for (;powe;bas*=bas,bas%=mo){
        if (powe&1) ret*=bas,ret%=mo;
      powe=powe>>1;    
    }
    return(ret);
  }

  LL gethash(matrix &a){
      LL ret=0;
      for (int i=1;i<=n;i++) for (int j=1;j<=n;j++)
        ret*=233,ret+=a.a[i][j];
      return(ret);
  }

  void hashin(matrix &a,int t){
      LL num=gethash(a);
      if (mpa[num]) mpb[num]=min(mpb[num],t);else{
        mpa[num]=1;mpb[num]=t;
    }
  }
  
  int query(matrix &a){
      static matrix tp;
      tp.cpy(b);
    tp.mul(a);
      LL num=gethash(tp);
      if (!mpa[num]) return(1e9);else return(mpb[num]);
  }
  
  void inv(matrix &a){
      static int tmp[201][201];
      for (int i=1;i<=n;i++)
        for (int j=1;j<=n;j++)
          tmp[i][j]=a.a[i][j],tmp[i][j+n]=(i==j);
      for (int i=1;i<=n;i++){
        LL t=qpow(tmp[i][i],mo-2);
         for (int j=i;j<=2*n;j++) tmp[i][j]=tmp[i][j]*t%mo;
      for (int j=1;j<=n;j++) if (tmp[j][i]&&i!=j){
          t=tmp[j][i];
          for (int k=i;k<=2*n;k++){
            tmp[j][k]-=t*tmp[i][k];
          tmp[j][k]%=mo;tmp[j][k]+=mo;tmp[j][k]%=mo;    
        }
      }            
    }
    
    for (int i=1;i<=n;i++) for (int j=1;j<=n;j++)
      a.a[i][j]=tmp[i][j+n];    
  }

  int main(){      
      scanf("%d%d",&n,&mo);
      for (int i=1;i<=n;i++) for (int j=1;j<=n;j++) scanf("%d",&a.a[i][j]);
      a.n=a.m=n;
      for (int i=1;i<=n;i++) for (int j=1;j<=n;j++) scanf("%d",&b.a[i][j]);
      b.n=b.m=n;
      tmp.n=tmp.m=n;
      
      int blsiz=sqrt(mo)+1;
      bas.cpy(a);tmp.zero();
      for (int i=1;i<=blsiz;i++) tmp.mul(bas);
      bas.cpy(tmp);
    tmp.zero();
    hashin(tmp,0);
      for (int i=1;i<=blsiz;i++){
        tmp.mul(bas);
      hashin(tmp,i*blsiz);    
    }
    
    int ans=1e9;
    inv(a);
    tmp.zero();
    ans=min(ans,query(tmp));
    for (int i=1;i<=blsiz;i++){
      tmp.mul(a);
      ans=min(ans,query(tmp)+i);
    }
    
    printf("%d\n",ans);
  }

 

posted @ 2017-02-22 19:29  z1j1n1  阅读(253)  评论(0编辑  收藏  举报