基于CRT加速的求解问题

引入&面向的问题

引入通常都比较抽象,可以不读qwq

对于一类形如给出一个数x并给出一个表达式使x做为表达式的第一项进行运算并输出结果取模P的这类问题。

显然可以通过线段树之类的来维护对于每类x取模P的情况下表达式的值。然而在P较大时,这种做法就变成了不优雅的\(O(Pnlog_2n)\)

对于这类问题就需要用到CRT。不会CRT先学CRT

题意

有一个形如\(x\quad op_1a_1op_2a_2...op_na_n\)的表达式。其中x是给定的变量,\(op_i\)是运算符+,*,^ 其中一个,其中^ 是乘法(具体的,a^b=\(a^b\),当前定义\(0^0\)=1)。运算符没有优先级,从左往右运算。还有一个模数P

有给定的两类操作

1 x0表示求x=x0时表达式的值模P

2 it cx将表达式中\(op_it\)修改为c,\(a_it\)修改为x

对于所有1操作输出答案取模P

n<=2e5,P都为合数原题中对于每个测试点的P都可以分成若干个小质数,如果没有这个条件CRT将无效

算法流程

不难想到,用线段树来维护每个区间对于每种x模P的情况。

具体的,一个线段树tree[N*4][0~(P-1)]。对于每个点it维护的区间转移方程为tree[it][0~(P-1)]=tree[it<<1|1][tree[it<<1][0~(P-1)]]较好理解,就是对于每个区间用左区间的值作为右区间的x得出的结果。

显然这样复杂度是\(O(Pnlog_2n)\)

考虑CRT优化复杂度。

对于测试点给的模数P,可以分成若干个质因数。我们不妨将P分解为若干个互质的因子,然后对与每个因子进行上面的线段树维护操作。

具体的对于一段区间it,枚举每个P分解的质因子,然后对于每个x模质因子后的数维护答案

代码转移:

inline void push_up(int i) {
    for (int I = 0; I < cnt; I++)
        for (int j = 0; j < pr[I]; j++) tree[i].val[I][j] = tree[i << 1 | 1].val[I][tree[i << 1].val[I][j]];
}

显然这种做法的复杂度为\(O(Snlog_2n)\)设S为P的因子和。

需要注意一个细节,对于乘方运算,如果幂过大会导致复杂度退化被卡,所以可以用扩展欧拉定理对于乘法运算优化

\[a^b\bmod c=a^{b\bmod \phi(c)+\phi(c)}\bmod c\quad b\geq \phi(c) \]

代码

#include <bits/stdc++.h>
using namespace std;
inline int read() {
    char c = getchar();
    int x = 0;
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    return x;
}
void write(int x) {
    if (x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}
struct segt {
    int l, r, val[5][30];
} tree[800002];
int cnt, pr[5], ny[5], a[200005], phi;
char op[200005];
inline long long ksm(long long b, int p, long long mod) {
    long long res = 1;
    while (p) {
        if (p & 1)
            res = res * b % mod;
        b = b * b % mod;
        p >>= 1;
    }
    return res;
}
void exgcd(int a, int b, long long& x, long long& y) {
    if (!b) {
        x = 1, y = 0;
        return;
    }
    exgcd(b, a % b, x, y);
    long long t = x;
    x = y;
    y = t - a / b * y;
}
inline void push_up(int i) {
    for (int I = 0; I < cnt; I++)
        for (int j = 0; j < pr[I]; j++) tree[i].val[I][j] = tree[i << 1 | 1].val[I][tree[i << 1].val[I][j]];
}
inline void updlef(int I, int x) {
    if (op[x] == '+')
        for (int i = 0; i < cnt; i++)
            for (long long j = 0; j < pr[i]; j++) tree[I].val[i][j] = (j + a[x]) % pr[i];
    else if (op[x] == '*')
        for (int i = 0; i < cnt; i++)
            for (long long j = 0; j < pr[i]; j++) tree[I].val[i][j] = (j * a[x]) % pr[i];
    else {
        if (a[x])
            a[x] = a[x] % phi + phi;
        for (int i = 0; i < cnt; i++)
            for (long long j = 0; j < pr[i]; j++) tree[I].val[i][j] = ksm(j, a[x], pr[i]);
    }
}
void build(int i, int le, int ri) {
    tree[i].l = le, tree[i].r = ri;
    if (le == ri) {
        updlef(i, le);
        return;
    }
    int mid = (le + ri) >> 1;
    build(i << 1, le, mid), build(i << 1 | 1, mid + 1, ri);
    push_up(i);
}
void change(int i, int x) {
    if (tree[i].l == tree[i].r) {
        updlef(i, x);
        return;
    }
    if (x > tree[i << 1].r)
        change(i << 1 | 1, x);
    else
        change(i << 1, x);
    push_up(i);
}
int gcd(int x, int y) {
    if (!y)
        return x;
    return gcd(y, x % y);
}
int main() {
    int id = read(), n = read(), m = read(), P = read(), i;
    for (i = 1; i <= n; i++) {
        char c = getchar();
        while (c != '+' && c != '*' && c != '^') c = getchar();
        op[i] = c;
        a[i] = read();
    }
    if (id < 4) {
        while (m--) {
            int x = read();
            if (x - 1) {
                x = read();
                char c = getchar();
                while (c != '+' && c != '*' && c != '^') c = getchar();
                op[x] = c;
                a[x] = read();
            } else {
                long long res = read();
                for (i = 1; i <= n; i++)
                    if (op[i] == '+')
                        res = (res + a[i]) % P;
                    else if (op[i] == '*')
                        res = (res * a[i]) % P;
                    else
                        res = ksm(res, a[i], P);
                write(res), putchar('\n');
            }
        }
        return 0;
    }
    for (i = 1; i < P; i++)
        if (gcd(i, P) == 1)
            phi++;
    long long ta, tb;
    int x = P, y;
    for (i = 2; i * i <= x; i++)
        if (x % i == 0) {
            y = 1;
            while (x % i == 0) x /= i, y *= i;
            exgcd(P / y, y, ta, tb);
            ny[cnt] = (ta % y + y) % y;
            pr[cnt++] = y;
        }
    if (x > 1) {
        exgcd(P / x, x, ta, tb);
        // cout<<ta<<" "<<tb<<endl;
        pr[cnt] = x, ny[cnt] = (ta % x + x) % x;
        cnt++;
    }
    // cout<<";;;";
    // for(int i=0;i<cnt;++i){
    //     cout<<pr[i]<<":"<<ny[i]<<" ";
    // }
    // puts("");
    build(1, 1, n);
    while (m--) {
        x = read();
        if (x - 1) {
            x = read();
            char c = getchar();
            while (c != '+' && c != '*' && c != '^') c = getchar();
            op[x] = c;
            a[x] = read();
            change(1, x);
        } else {
            x = read();
            long long res = 0;
            for (i = 0; i < cnt; i++)
                res = (res + (long long)P / pr[i] * ny[i] % P * tree[1].val[i][x % pr[i]]) % P;
            write(res), putchar('\n');
        }
    }
    return 0;
}
posted @ 2022-10-13 10:03  SZBR_yzh  阅读(44)  评论(1编辑  收藏  举报