BZOJ 4128 Matrix BSGS+矩阵求逆

题意:链接

方法: BSGS+矩阵求逆

解析:

这题就是把Ax=B(mod C)的A和B换成了矩阵。

然而别的地方并没有修改。

所以就涉及到矩阵的逆元这个问题。

矩阵的逆元怎么求呢?

先在原矩阵后接一个单位矩阵,最好还是设右对角线

先把原矩阵进行高斯消元

且消成严格右对角线的单位矩阵的形式。

然后在消元的同一时候把单位矩阵的部分一并计算。最后单位矩阵就变成了它的逆矩阵。

这道题保证矩阵有逆

然而没有逆矩阵的情况就是高斯消元搞不成。

所以推断应该也好推断。

另外,刚刚实測本题数据。关于将矩阵的hash,直接取右下角的值即可了。太弱了数据

代码:

#include <map>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 75
#define M 140345
#define base 0
using namespace std;
int n,p;
struct Matrix
{
    int map[N][N];
}A,B,ret;
struct node
{
    int from,to,next;
    int val;
}edge[M+10];
int cnt,head[M+10];
void init()
{
    memset(head,-1,sizeof(head));
    cnt=1;
}
void edgeadd(int from,int to,int val)
{
    edge[cnt].from=from,edge[cnt].to=to,edge[cnt].val=val;
    edge[cnt].next=head[from];
    head[from]=cnt++;
}
Matrix mul(Matrix a,Matrix b)
{
    memset(ret.map,0,sizeof(ret.map));
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            for(int k=1;k<=n;k++)
                ret.map[i][j]=(ret.map[i][j]+a.map[i][k]*b.map[k][j])%p;
    return ret;
}
int quick_my(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ret*x)%p;
        x=(x*x)%p;
        y>>=1;
    }
    return ret;
}
int hash(Matrix x)
{
    return x.map[n][n];
}
Matrix get_inv(Matrix x)
{
    memset(ret.map,0,sizeof(ret.map));
    for(int i=1;i<=n;i++)ret.map[i][i]=1;
    for(int i=1;i<=n;i++)
    {
        int chose=-1;
        for(int j=i;j<=n;j++)
            if(x.map[j][i]!=0){chose=j;break;}
        //if(chose==-1)
        //  return -1
        //无解我脑补应该是这样吧。
        for(int j=1;j<=n;j++)
            swap(x.map[i][j],x.map[chose][j]),swap(ret.map[i][j],ret.map[chose][j]);
        int inv=quick_my(x.map[i][i],p-2);
        for(int j=1;j<=n;j++)
        { 
            x.map[i][j]=x.map[i][j]*inv%p;
            ret.map[i][j]=ret.map[i][j]*inv%p; 
        } 
        for(int j=1;j<=n;j++)
        {
            if(i==j)continue;
            int pre=x.map[j][i];//一定要提前记录不然肯定会影响答案。由于这个值被改变=-=我也是脑残了。
            for(int k=1;k<=n;k++)
            {
                x.map[j][k]=((x.map[j][k]-pre*x.map[i][k])%p+p)%p;
                ret.map[j][k]=((ret.map[j][k]-pre*ret.map[i][k])%p+p)%p;
            }
        }
    }
    return ret; 
}
int BSGS(Matrix A,Matrix B,int C)
{
    init();
    int m=(int)ceil(sqrt(C));
    Matrix tmp;
    memset(tmp.map,0,sizeof(tmp.map));
    for(int i=1;i<=n;i++)tmp.map[i][i]=1;
    for(int i=0;i<m;i++)
    {
        int hashtmp=hash(tmp);
        int flag=1;
        for(int j=head[hashtmp%M];j!=-1;j=edge[j].next)
            if(edge[j].val==hashtmp){flag=0;break;}
        if(flag)edgeadd(hashtmp%M,i,hashtmp);
        tmp=mul(tmp,A);
    }
    Matrix inv=get_inv(tmp);
    for(int i=0;i<=m;i++)
    {
        int hashB=hash(B);
        for(int j=head[hashB%M];j!=-1;j=edge[j].next)
            if(edge[j].val==hashB)return i*m+edge[j].to;
        B=mul(B,inv);
    }
    return -1;
}
int main()
{
    scanf("%d%d",&n,&p);
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            scanf("%d",&A.map[i][j]);
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            scanf("%d",&B.map[i][j]);
    printf("%d\n",BSGS(A,B,p));
}
posted @ 2017-07-27 09:39  yfceshi  阅读(321)  评论(0编辑  收藏  举报