[SNOI2019]数论

题目

考虑对于每一个\(a_i\)计算有多少个\(0<x\leq T-1\)满足\(x\equiv a_i(mod\ P)\)\(x\ mod\ Q \in B\)

显然\(x=a_i+k\times P\),先考虑一下这个\(k\)最大能取到多少,显然有\(a_i+k\times P\leq T-1\),所以\(k\)最大取到\(\left \lfloor \frac{T-1-a_i}{P} \right \rfloor\)

我们这样加下去,肯定会使得\(x\)\(mod\ Q\)意义下循环,我们尝试利用一下这个循环的性质

一旦\(k\times P\equiv 0(mod\ Q)\)就会循环,显然\(k=\frac{lcm(P,Q)}{P}=\frac{Q}{gcd(P,Q)}\)就是最小循环节

再考虑把这个问题转化成一个图论问题,我们对于每一个\(0<x<Q\),都建一条向\((x+P)\% Q\)的边,属于\(B\)集合中的点的点权是\(1\),我们要做的是求出每一个\(a_i\)\(\left \lfloor \frac{T-1-a_i}{P} \right \rfloor\)步到达的点的点权和

由于每一个点只有一条出边,那么就说明这个图里只会存在一些简单环,而且每一个环的大小都是\(\frac{Q}{gcd(P,Q)}\)

于是现在我们就可以把每一个环都找出来,维护一下每一个环的点权和,再维护一下环的前缀和,这样我们就能快速计算每一个点走一定步数到达的点的点权和了

边界情况需要好好判断一下

代码

#include <bits/stdc++.h>
#define LL long long
#define re register
const int maxn = 1e6 + 5;
LL T, ans, p[maxn];
std::vector<int> s[maxn], v[maxn];
int len, P, Q, n, m, a[maxn], b[maxn], tax[maxn], w[maxn], col[maxn], pos[maxn], num;
inline int read() {
    char c = getchar();int x = 0;
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') x = (x << 3) + (x << 1) + c - 48, c = getchar();
    return x;
}
int gcd(int a, int b) { return !b ? a : gcd(b, a % b); }
int dfs(int x, int c) {
    if (col[x]) return 0;
    col[x] = c;
    v[col[x]].push_back(x);
    return tax[x] + dfs((x + P) % Q, c);
}
inline int find(int l, int x) { return s[col[x]][pos[x] + l] - s[col[x]][pos[x]]; }
int main() {
    P = read(), Q = read(), n = read(), m = read(), scanf("%lld", &T);
    for (re int i = 1; i <= n; i++) a[i] = read();
    for (re int i = 1; i <= m; i++) b[i] = read();
    if (P > Q)
        std::swap(P, Q), std::swap(n, m), std::swap(a, b);
    len = Q / gcd(P, Q);
    for (re int i = 1; i <= m; i++) tax[b[i]] = 1;
    for (re int i = 1; i <= n; i++) p[i] = (T - 1 - a[i]) / P;
    for (re int i = 0; i < Q; i++)
        if (!col[i])
            ++num, w[num] = dfs(i, num);
    for (re int i = 1; i <= num; i++) {
        for (re int j = 0; j < v[i].size(); j++) pos[v[i][j]] = j;
        int t = v[i].size() - 1;
        for (re int j = 0; j < t; j++) v[i].push_back(v[i][j]);
        s[i].push_back(tax[v[i][0]]);
        for (re int j = 1; j < v[i].size(); j++) s[i].push_back(s[i][j - 1] + tax[v[i][j]]);
    }
    for (re int i = 1; i <= n; i++) {
        ans += 1ll * (p[i] / len) * (w[col[a[i]]]);
        ans += find(p[i] % len, a[i]) + tax[a[i]];
    }
    printf("%lld\n", ans);
    return 0;
}
posted @ 2019-06-25 14:07  asuldb  阅读(259)  评论(0编辑  收藏  举报