HDU 6036 - Division Game | 2017 Multi-University Training Contest 1
/* HDU 6036 - Division Game [ 组合数学,NTT ] | 2017 Multi-University Training Contest 1 题意: k堆石子围成一个圈,数量均为n,编号为0至k-1 第i轮可以操作第 (i+1) mod k 堆石子,必须拿石子且原石子数量要求整除操作后石子数量 任意一堆石子只剩一颗后停止游戏,问游戏停止在第i堆的方案数 限制: n很大,按唯一分解定理形式给出 n = p1^e1 * p2^e2 * ... * pm^em m,k <= 10 ∑ei <= 1e5 分析: 设 w = ∑ei ,则每堆石子最多操作w次 设 F(x) 为一个堆操作 x 次恰好变为1的方案数,则一个堆操作 x-1 次后不变为 1 的方案数也为 F(x) 对于石子堆i的方案数,枚举总操作次数x,设此时方案数为 ans[i][x] 则显然当石子堆i操作x次时,0 < j < i 的石子堆 j 均操作x次,而i < j <= k 的石子堆均操作x-1 故 ans[i][x] = [0<j<i的堆操作x次后不为1] * [堆i操作x次恰好为1] * [i<j<=k的堆操作x-1次后不为1] = F(x+1)^(i-1) * F(x) * F(x)^(k-i) = F(x+1)^(i-1) * F(x)^(k-i+1) 现在研究一下上式 x 的取值范围 由于每个堆至多操作 w 次,则 [0<j<i的堆操作x次后不为1] 的 x ∈[0 , w-1] [堆i操作x次恰好为1] 的 x ∈[0 , w] [i<j<=k的堆操作x-1次后不为1] 的 x ∈[0 , w+1] 易得出结论 x ∈[0 , w-1] 但当 i = 1 时,0<j<i的堆 不存在,故其可以取到 x = w 所以分类讨论: 当 i = 1 时,ans[i][x] = F(x+1)^(i-1) * F(x)^(k-i+1) + F(w) , x ∈[0 , w-1] 当 i > 1 时,ans[i][x] = F(x+1)^(i-1) * F(x)^(k-i+1) , x ∈[0 , w-1] 研究如何求出 F(x): 此时有两个限制条件 1. 对于每个质因子pi,在 x 次内取完 即ei个相同的数字分成x组,允许空组,根据挡板法,方案数为 Comb(ei+x-1, x-1) 根据乘法原理 总方案数 F'(x) = ∏ Comb(ei+x-1, x-1) [1<=i<=m] 但由于存在第二个限制,F(x) != F'(x) 2. 每一次至少存在一个质因子被取 则根据容斥原理 F(x) = 随意取的方案数 - 某一次什么都没有取的方案数 + 某两次什么都没有取的方案数 - 某三次什么都没有取的方案数 + ... 由于 x次中k次没有取的方案数 = Comb(x,k) * x-k次恰好取完的方案数 = Comb(x,k) *F(x-k) 则 F(x) = F'(x) - Comb(x,1) * F(x-1) + Comb(x,2) * F(x-2) - Comb(x,3) * F(x-3) + ... = Σ (-1)^i * C(x, i) * F(x-i) [0<=i<x] 将组合数打开优化: F(x) = Σ (-1)^i * x! / i! / (x-i)! * F(x-i) [0<=i<x] F(x)/x! = Σ (-1)^i/i! * F(x-i)/(x-i)! [0<=i<x] 可以看出是卷积,再考虑模数特殊,用 NTT 优化 对代码常数有一定要求 */ #include <bits/stdc++.h> using namespace std; #define LL long long const int N = 100005; const int MOD = 985661441; namespace NTT{ const int G = 3; const int NUM = 20; int wn[20]; int mul(int x, int y) { return (LL)x*y%MOD; } int PowMod(int a, int b) { int res = 1; a %= MOD; while (b) { if (b&1) res = mul(res, a); a = mul(a, a); b >>= 1; } return res; } void GetWn() { for (int i = 0; i < NUM; i++) { int t = 1<<i; wn[i] = PowMod(G, (MOD-1)/t); } } void Change(int a[], int len) { int i, j, k; for (i = 1, j = len/2; i < len-1; i++) { if (i < j) swap(a[i], a[j]); k = len/2; while (j >= k) { j -= k; k /= 2; } if (j < k) j += k; } } void NTT(int a[], int len, int on) { Change(a, len); int id = 0; for (int h = 2; h <= len; h <<= 1) { id++; for (int j = 0; j < len; j += h) { int w = 1; for (int k = j; k < j + h/2; k++) { int u = a[k] % MOD; int t = mul(a[k+h/2], w); a[k] = (u+t) % MOD; a[k+h/2] = ((u-t)%MOD + MOD) % MOD; w = mul(w, wn[id]); } } } if (on == -1) { for (int i = 1; i < len/2; i++) swap(a[i], a[len-i]); int inv = PowMod(len, MOD-2); for (int i = 0; i < len; i++) a[i] = mul(a[i], inv); } } } int a[N<<3], b[N<<3]; namespace COMB{ int F[N<<1], Finv[N<<1], inv[N<<1]; void init() { inv[1] = 1; for (int i = 2; i < N<<1; i++) { inv[i] = (LL)(MOD - MOD/i) * inv[MOD%i] % MOD; } F[0] = Finv[0] = 1; for (int i = 1; i < N<<1; i++) { F[i] = (LL)F[i-1] * i % MOD; Finv[i] = (LL)Finv[i-1] * inv[i] % MOD; } } int comb(int n, int m) { if (m < 0 || m > n) return 0; return (LL)F[n] * Finv[n-m] % MOD * Finv[m] % MOD; } } using namespace COMB; int e[20], m, k, n, tt; int g[N], ans[20]; int pa[2][N]; int main() { int i, x, len, tt = 0; NTT::GetWn(); init(); while (~scanf("%d%d", &m, &k)) { n = 0; for (i = 1; i <= m; ++i) { scanf("%*d%d", &e[i]); n += e[i]; } ++n; for (x = 0; x < n; ++x) { g[x] = 1; for (i = 1; i <= m; ++i) g[x] = (LL)g[x] * comb(e[i]+x-1, x-1) % MOD; } for (i = 0; i < n; ++i) { a[i] = (i%2 ? MOD - Finv[i] : Finv[i]); b[i] = (LL)g[i] * Finv[i] % MOD; } len = 1; while (len < n*2) len <<= 1; for (i = n; i < len; ++i) a[i] = b[i] = 0; NTT::NTT(a, len, 1); NTT::NTT(b, len, 1); for (int i = 0; i < len; ++i) a[i] = NTT::mul(a[i], b[i]); NTT::NTT(a, len, -1); for (i = 0; i < n; ++i) a[i] = (LL)a[i] * F[i] % MOD; memset(ans, 0, sizeof(ans)); a[n] = 0; int pre = 1, cur = 0; for (x = 1; x < n; ++x) { pre ^= 1, cur ^= 1; pa[cur][0] = 1; for (i = 1; i <= k; ++i) pa[cur][i] = (LL)pa[cur][i-1] * a[x] % MOD; if (x > 1) for (i = 1; i <= k; ++i) ans[i] = (ans[i] + (LL)pa[cur][i-1] * pa[pre][k-i+1]) % MOD; } ans[1] += pa[cur][k]; if (ans[1] > MOD) ans[1] -= MOD; printf("Case #%d:", ++tt); for (i = 1; i <= k; ++i) printf(" %d", ans[i]); puts(""); } }
*修正了一下写错的部分
我自倾杯,君且随意