BZOJ4128: Matrix(BSGS 矩阵乘法)

Time Limit: 10 Sec  Memory Limit: 128 MB
Submit: 813  Solved: 442
[Submit][Status][Discuss]

Description

给定矩阵A,B和模数p,求最小的x满足

A^x = B (mod p)

 

Input

第一行两个整数n和p,表示矩阵的阶和模数,接下来一个n * n的矩阵A.接下来一个n * n的矩阵B

 

Output

输出一个正整数,表示最小的可能的x,数据保证在p内有解

 

Sample Input

2 7
1 1
1 0
5 3
3 2

Sample Output

4

HINT

 

对于100%的数据,n <= 70,p <=19997,p为质数,0<= A_{ij},B_{ij}< p

保证A有逆
 

Source

裸的BSGS,把$x$分解为$im - j$

原式化为$a^{im} \equiv ba^j \pmod p$

其中$m = \ceil{sqrt(p)}$

然后枚举一个$j$,存到map里

再枚举一个$i$判断即可

一开始map写成bool类型了调了半个小时

#include<cstdio>
#include<algorithm>
#include<cmath>
#include<map>
//#define LL long long  
using namespace std;
const int MAXN = 4 * 1e5 + 10;
inline int read() {
    char c = getchar(); int x = 0, f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}
int N, mod, M;
struct Matrix {
    int m[71][71];
    Matrix operator * (const Matrix &rhs) const {
        Matrix ans = {};
        for(int i = 1; i <= N; i++)    
            for(int j = 1; j <= N; j++)
                for(int k = 1; k <= N; k++)
                    (ans.m[i][j] += m[i][k] * rhs.m[k][j]) %= mod;
        return ans; 
    }
    void init() {
        for(int i = 1; i <= N; i++)
            for(int j = 1; j <= N; j++)    
                m[i][j] = read();
    }
    void print() {
        for(int i = 1; i <= N; i++, puts(""))
            for(int j = 1; j <= N; j++)
                printf("%d ", m[i][j]);
    }
    bool operator < (const Matrix &rhs) const {
        for(int i = 1; i <= 70; i++)
            for(int j = 1; j <= 70; j++) {
                if(m[i][j] < rhs.m[i][j]) return 1;
                if(m[i][j] > rhs.m[i][j]) return 0;
            }
        return 0;
    }
}A, B;
map<Matrix, int> mp;
void MakeMap() {
    Matrix a = B;
    mp[a] = 0;
    for(int i = 1; i <= M; i++) a = a * A, mp[a] = i;
}
void FindAns() {
    Matrix a, am = A;
    for(int i = 1; i <= M - 1; i++) am = am * A;
    a = am;
    for(int i = 1; i <= M; i++) {
        if(mp[a]) printf("%d", i * M - mp[a]), exit(0);
        a = a * am;
    }
}
main() {
    N = read(); mod = read();
    A.init(); B.init();
    M = (double)ceil(sqrt(mod));
    MakeMap();
    FindAns();
}

 

posted @ 2018-07-10 17:54  自为风月马前卒  阅读(560)  评论(3编辑  收藏  举报

Contact with me