[BZOJ 3129] [Sdoi2013] 方程 【容斥+组合数取模+中国剩余定理】
题目链接:BZOJ - 3129
题目分析
使用隔板法的思想,如果没有任何限制条件,那么方案数就是 C(m - 1, n - 1)。
如果有一个限制条件是 xi >= Ai ,那么我们就可以将 m 减去 Ai - 1 ,相当于将这一部分固定分给 xi,就转化为无限制的情况了。
如果有一些限制条件是 xi <= Ai 呢?直接来求就不行了,但是注意到这样的限制不超过 8 个,我们可以使用容斥原理来求。
考虑容斥:考虑哪些限制条件被违反了,也就是说,有哪些限制为 xi <= Ai 却是 xi > Ai,这样就转化为了 xi >= Ai 的限制条件。
那么我们就可以在 2^8 * T(求组合数) 的时间内求出答案了。
怎么求这个组合数呢?直接预处理阶乘的逆元是不可以的,因为模数不都是质数。
我们要将模数拆成一个个 pi^ai 这样的形式,使得它们两两之间互质,就可以分别求出答案,最后再用中国剩余定理组合起来。
中国剩余定理:如果有n个方程 x = xi (mod mi) ,M = m1 * m2 * .. * mn ,那么在 mod M 的意义下,方程组有一个唯一解。
x = sigma(Mi * Inv(Mi) * xi) % M ,其中 Mi = M / mi ,Inv(Mi)是Mi在mod mi意义下的逆元。
那么我们的问题就是,如何求出 C(n, m) % (p^a) 。
这里就需要用到“组合数取模”了,专门用来求解这种问题。
使用类似于快速阶乘的方法,将组合数中分数线上下的阶乘都拆成 e * p^f 的形式,然后 e 直接计算,f 分数线上下相减之后再计算。
怎么将 x! 拆成 e * p^f 呢?
假设我们要 mod 的数是 p^a ,那么我们需要预处理出 [1, p^a - 1] 中除去 p 的倍数的其余数的前缀积(类似阶乘少了 p 的倍数)。
然后我们知道 [1, x] 中包含 p 的数有 x / p 个,我们将这些数中都提取出 1 个 p,那么就获得了 p^(x/p),然后这 x / p 个数就变成了 [1, x/p],就可以递归下去。
其余的部分可以分段来求,分成 [1, p^a - 1], [p^a + 1, p^a + p^a - 1] ..... 这样,每一段的积都是一样的 (mod p^a 意义下),直接快速幂就可以了。
最后还会剩下一段 [1, x % (p^a)] ,也是直接预处理出的值。
这样这道题就做完了(呼~)。
另外注意的是,在写代码的时候,我求逆元使用欧拉定理但是确用错了。
欧拉定理:a^phi(b) = 1 (mod b) 条件:gcd(a, b) = 1
注意是 a^phi(b) 而不是 a^(b-1) !当 b 不是质数的时候就跪了!
代码
#include <iostream> #include <cstdlib> #include <cstring> #include <algorithm> #include <cmath> #include <cstdio> using namespace std; typedef long long LL; typedef double LF; const int MaxP = 10201 + 15, MaxN1 = 8 + 5; int T, p, n, n1, n2, m, Top, Ans; int A[MaxN1]; LL Temp; LL Fac[10][MaxP], Pr[10], Pi[10], Pa[10], Phi_Pi[10], Mi[10], Inv_Mi[10], Xi[10]; LL Pow(LL a, LL b, LL Mod) { LL ret, f; ret = 1; f = a; while (b) { if (b & 1) { ret *= f; ret %= Mod; } b >>= 1; f *= f; f %= Mod; } return ret; } void Prepare() { int x, SqrtX; x = p; SqrtX = (int)sqrt((LF)x); Top = 0; for (int i = 2; i <= SqrtX; ++i) { if (x % i != 0) continue; Pr[++Top] = i; Pa[Top] = 0; Pi[Top] = 1; while (x % i == 0) { ++Pa[Top]; Pi[Top] *= i; x /= i; } Phi_Pi[Top] = Pi[Top] / Pr[Top] * (Pr[Top] - 1); } if (x > 1) { Pr[++Top] = x; Pa[Top] = 1; Pi[Top] = x; Phi_Pi[Top] = Pi[Top] - 1; } for (int i = 1; i <= Top; ++i) { Mi[i] = p / Pi[i]; Inv_Mi[i] = Pow(Mi[i], Phi_Pi[i] - 1, Pi[i]); Fac[i][0] = 1; for (int j = 1; j < Pi[i]; ++j) { if (j % Pr[i] != 0) Fac[i][j] = Fac[i][j - 1] * j % Pi[i]; else Fac[i][j] = Fac[i][j - 1]; } } } struct ES { LL e, f; }; ES Calc(int x, int k) { ES ret, tc; if (x < Pr[k]) { ret.e = Fac[k][x]; ret.f = 0; return ret; } ret.f = x / Pr[k]; tc = Calc(x / Pr[k], k); ret.f += tc.f; ret.e = tc.e * Fac[k][x % Pi[k]] % Pi[k]; ret.e = ret.e * Pow(Fac[k][Pi[k] - 1], x / Pi[k], Pi[k]) % Pi[k]; return ret; } LL C(int x, int y, int k) { LL ret; int pf; ES Ex, Ey, Exy; Ex = Calc(x, k); Ey = Calc(y, k); Exy = Calc(x - y, k); ret = Ex.e * Pow(Ey.e, Phi_Pi[k] - 1, Pi[k]) % Pi[k] * Pow(Exy.e, Phi_Pi[k] - 1, Pi[k]) % Pi[k]; pf = Ex.f - Ey.f - Exy.f; if (pf >= Pa[k]) ret = 0; else ret = ret * Pow(Pr[k], pf, Pi[k]) % Pi[k]; return ret; } int C(int x, int y) { if (x == y) return 1; if (x < y) return 0; if (y == 0) return 1; LL ret = 0; for (int i = 1; i <= Top; ++i) Xi[i] = C(x, y, i); for (int i = 1; i <= Top; ++i) { ret += Xi[i] * Mi[i] % p * Inv_Mi[i] % p; ret %= p; } return (int)ret; } void DFS(int x, int Cnt, int Sum) { if (x == n1) { if (Cnt & 1) Temp -= C(m - Sum - 1, n - 1); else Temp += C(m - Sum - 1, n - 1); Temp = (Temp % p + p) % p; return; } DFS(x + 1, Cnt, Sum); DFS(x + 1, Cnt + 1, Sum + A[x + 1]); } int Solve() { Temp = 0; DFS(1, 0, 0); DFS(1, 1, A[1]); return (int)Temp; } int main() { scanf("%d%d", &T, &p); Prepare(); for (int Case = 1; Case <= T; ++Case) { scanf("%d%d%d%d", &n, &n1, &n2, &m); for (int i = 1; i <= n1; ++i) scanf("%d", &A[i]); int Num; for (int i = 1; i <= n2; ++i) { scanf("%d", &Num); m -= Num - 1; } if (n1 > 0) Ans = Solve(); else Ans = C(m - 1, n - 1); printf("%d\n", Ans); } return 0; }