Codeforces 360D Levko and Sets (数论好题)
题意:有一个长度为n的数组a和一个长度为m的数组b,一个素数p。有n个集合,初始都只有一个1。现在,对(i从1到n)第i个集合执行以下操作:
对所有集合中的元素c,把c * (a[i] ^ b[j]) mod p 加入集合(j从1到m), 直到集合的元素不再增加为止。
问最后这n个集合的并有多少个元素?
n到1e4, m到1e5, p到1e9。
思路(官方题解)这题运用了很多数论的知识,不对数论有一定了解比较难做出这道题。
涉及的知识:原根,阶,欧拉定理,贝祖定理。
首先我们知道,x ^ y mod p = x ^ (y mod phi(p)) mod p(欧拉定理),而集合的操作可以看成a的指数相加,所以我们第一步先求出所有b和p - 1的gcd = t,这样每个集合的大小是(p - 1) / t,(贝祖定理),集合中的元素为a[i] ^ (k * t)。但是,由于a是不一定相同的,所以我们很难直接合并。我们可以转化成原根的形式,那么假设g ^ (r[i]) = a[i], ,令q[i] = gcd(r[i], p - 1)(贝祖定理)。我们不妨令g = g ^ t,因为t是一个常数。那么,现在集合中的元素为g ^ (q[i] * k), 我们现在需要求出q[i]。怎么求q[i]呢?我们只要求出a[i] ^ t的阶 = l,然后(p - 1) / l就是q[i], 为什么呢?因为(a[i] ^ t) ^ l mod p = 1 mod p = g ^ (p - 1) mod p => a[i] ^ t mod p= g ^ ((p - 1) / l) mod p, 这样就可以算出对应的q[i],即a[i] ^ t的原根中的指标。现在我们求出了q[i],那么所有g ^ (q[i] * k)会出现在集合中。一种想法是用筛法,但是p很大,不能筛。这里需要用到容斥原理。
注意,cf上很多代码的容斥是假的,可以被hack。
以下是来自cf评论区的两个数据:
2 1 13
3 5
1
答案:6
2 1 37
31 27
1
答案:8
正确容斥做法:先处理出p - 1的所有因子,从大到小枚举,枚举因子时判断是否有q[i],始得这个因子是q[i]的倍数,如果有,说明这个因子可以对答案产生贡献,再枚举比它大的因子,如果有因子是这个因子的倍数,那么减去这个因子的贡献。最后剩下的是这个因子的贡献。如果没有,跳过即可。均摊的复杂度应该是O(log ^ 2(p))。
代码:
#include <bits/stdc++.h> #define LL long long using namespace std; const int maxn = 100010; int n, m, p; int a[maxn], b[maxn], q[maxn], dp[maxn]; int qpow(int x, int y) { int ans = 1; for (; y; y >>= 1) { if(y & 1) ans = ((LL)ans * x) % p; x = ((LL)x * x) % p; } return ans; } vector<int> re; void div(int x) { for (int i = 1; i * i <= x; i++) { if(x % i == 0) { re.push_back(i); if(i * i != x) re.push_back(x / i); } } } int main() { scanf("%d%d%d", &n, &m, &p); for (int i = 1; i <= n; i++) { scanf("%d", &a[i]); } int t = p - 1; for (int i = 1; i <= m; i++) { scanf("%d", &b[i]); t = __gcd(t, b[i]); } div(p - 1); sort(re.begin(), re.end()); for (int i = 1; i <= n; i++) { a[i] = qpow(a[i], t); for (int j = 0; j < re.size(); j++) { if(qpow(a[i], re[j]) == 1) { q[i] = (p - 1) / re[j]; break; } } } sort(q + 1, q + 1 + n); n = unique(q + 1, q + 1 + n) - (q + 1); int ans = 0; for (int j = re.size() - 1; j >= 0; j--) { bool flag = 0; for (int i = 1; i <= n; i++) { if(re[j] % q[i] == 0) { flag = 1; for (int k = j + 1; k < re.size(); k++) { if(re[k] % re[j] == 0) dp[j] -= dp[k]; } break; } } if(flag) { dp[j] += (p - 1) / re[j]; ans += dp[j]; } } printf("%d\n", ans); }