cf1106F Lunar New Year and a Recursive Sequence - 矩阵快速幂 + 离散对数BSGS / N次剩余 + 欧拉降幂

传送门

根据公式可以退出\(i>k\)时的\(f_i = f_i^{g_i}\), 其中\(g_i\)是存在一个递推式\(g_i = \sum_{j = 1}^kb_j \times g_{i - j}\)
已知\(k \in [1,100]\),那么就可以用矩阵快速幂求出\(f_n\)的幂\(k\),那么

\(f_n = f_k^k \equiv m(\mod 998244353 )\)

相当于变成了\(x^a\equiv m(\mod 998244353 )\),其中\(a, m\)已知。

那不就是N次剩余吗?试了下洛谷的板子,mle了。。。

其实这里可以进行优化,对于N次剩余来说,如果模存在离散对数,那么可以进行化简

\(a\times ind_g x \equiv ind_g m(\mod P - 1)\)

其中\(x = g^{ind_g x + P - 1 } \equiv (\mod P)\)

其中g就是模\(P\)的一个原根。
\(ind_g m\)\(m\)在模的一个原根\(g\)下的离散对数
利用同余方程就可以解出\(ind_g x\)的值,最后再代换回来就行了

注意一点就是用矩阵快速幂求幂,且模是质数,那么幂就去欧拉降幂,模\(P - 1\)

#include <bits/stdc++.h>
#define ll long long
#define CASE int Kase = 0; cin >> Kase; for(int kase = 1; kase <= Kase; kase++)
using namespace std;
template<typename T = long long> inline T read() {
    T s = 0, f = 1; char ch = getchar();
    while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
    while(isdigit(ch)) {s = (s << 3) + (s << 1) + ch - 48; ch = getchar();} 
    return s * f;
}
const int N = 100 + 5, M = 100 + 5, MOD = 1e9 + 7, CM = 998244353, INF = 0x3f3f3f3f, g = 3;
struct Matrix{
    int n, m;
    int a[M][M];
    Matrix(int n = 0, int m = 0):n(n),m(m){memset(a, 0, sizeof(a));};
    Matrix operator * (const Matrix &b) const {
        Matrix ans(n,b.m);
        for(int i = 0; i < n; i++)
            for(int j = 0; j < b.m; j++)
                for(int k = 0; k < m; k++)
                    ans.a[i][j] = (ans.a[i][j] + 1ll * a[i][k] * b.a[k][j] % (CM - 1)) % (CM - 1);
        return ans;
    }
};
Matrix ksm(Matrix a, int b){
    Matrix ans(a.n, a.m);
    for(int i = 0; i < max(a.n, a.m); i++) ans.a[i][i] = 1;
    while(b) {
        if(b & 1) ans = ans * a;
        a = a * a;
        b >>= 1;
    }
    return ans;
}
int b[N];
ll cal(int n, int k){
    Matrix base(k, k);
    for(int i = 0; i < k; i++) base.a[0][i] = b[i + 1];
    for(int i = 1; i < k; i++) base.a[i][i - 1] = 1;
    Matrix ans(k, 1);
    ans.a[0][0] = 1;
    base = ksm(base, n - k);
    ans = base * ans;
    return ans.a[0][0]; 
}
ll qpow(ll a, ll b, ll p){
    ll ans = 1; a %= p;
    while(b){
        if(b & 1) ans = ans * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return ans;
}
map<int,int>mp;
int bsgs(int x, int P){
    int m = sqrt(P) + 1; mp.clear();
    for(int i = 0, res = x; i < m; ++i, res = 1ll * res * g % P) mp[res] = i;
    for(int i = 1, tmp = qpow(g, m, P), res = tmp; i <= m + 1; ++i, res = 1ll * res * tmp % P)
        if(mp.count(res)) return i * m - mp[res];
    return 0;
}
void ex_gcd(ll a, ll b, ll &d, ll &x, ll &y){
    if(!b){
        d = a, x = 1, y = 0;
        return;
    }
    ex_gcd(b,a % b,d,y,x);
    y -= x * (a / b);
}
ll gcd(ll a, ll b){
    return b == 0 ? a : gcd(b, a % b);
}
void cal(ll a, ll &x, ll b, ll m){ // 求ax == b (mod m) 的解x
    ll d = gcd(a, m), y = 0;
    if(b % d) return x = -1, void(0);
    ex_gcd(a, m, d, x, y);
    x = (b / d * x + m / d) % (m / d); // 最小正整数解
}
void solve(int kase){
    int k = read();
    for(int i = 1; i <= k; i++) b[i] = read();
    int n = read(), m = read();
    ll qpower = cal(n, k);
    ll ans = 0;
    cal(qpower, ans, bsgs(m, CM), CM - 1);
    if(ans == -1) printf("-1\n");
    else printf("%lld\n", qpow(g, ans + CM - 1, CM));
}
const bool DUO = 0;
int main(){
    clock_t start, finish; double totaltime; start = clock();
    if(DUO) {CASE solve(kase);} else solve(1);
    finish = clock(); 
    #ifdef ONLINE_JUDGE
        return 0;
    #endif
    printf("\nTime: %lfms\n", (double)(finish - start) / CLOCKS_PER_SEC * 1000);
    return 0;
}
posted @ 2021-02-07 19:02  Emcikem  阅读(87)  评论(0编辑  收藏  举报