Luogu4916 魔力环 莫比乌斯反演、组合、生成函数
先不考虑循环同构的限制,那么对于一个满足条件的序列,如果它的循环节长度为\(d\),那么与它同构的环在答案中就会贡献\(d\)次。
所以如果设\(f_i\)表示循环节长度恰好为\(i\)的满足条件的序列个数(不考虑循环同构),那么最后的答案就是\(\sum \frac{f_i}{i}\)。
所以问题变成了如何求\(f_i\)。注意到\(f_i\)直接求不是很好求,考虑计算\(cnt(\frac{n}{d} , \frac{m}{d})\)表示珠子数为\(\frac{n}{d}\)、黑色珠子数为\(\frac{m}{d}\)、不考虑循环同构的合法方案数,不难注意到\(\sum\limits_{i | d} f_i = cnt(\frac{n}{d} , \frac{m}{d})\)。所以只需要把所有\(cnt(\frac{n}{d} , \frac{m}{d})\)算出来然后莫比乌斯反演一下就可以得到所有\(f_i\)。
然后我们将原问题变成了不需要考虑循环同构的问题\(cnt(a,b)\)。
对于\(cnt(a,b)\),考虑\(a-b\)个白色球产生的\(a-b+1\)个区间,每一个区间内放入的黑色球的数量不能超过\(k\),且首尾放入的球的数量之和不能超过\(k\)。也就是要求\(\sum\limits_{i=0}^{a-b} x_i = b , \forall i , x_i \leq k , x_0 + x_{a-b} = k\)的满足条件的\(x\)序列的数量。不难得到这个序列的生成函数为\((\sum\limits_{i=0}^k x^i)^{a-b-1} (\sum\limits_{i=0}^k(i+1)x^i)\),我们要求的是它的\(x^b\)项系数。显然多项式快速幂不够优秀,考虑更快的方法。
由\(\sum\limits_{i=0}^k x_i = \frac{1 - x^{k+1}}{1 - x}\),可以得到
\((\sum\limits_{i=0}^k x^i)^{a - b - 1} = (1 - x^{k+1})^{a-b-1}(1 - x)^{-(a-b-1)}\)
\(\begin{align*}\sum\limits_{i=0}^k(i + 1)x^i &= \sum\limits_{i=0}^k \sum\limits_{j=i}^k x^i \\ &= \sum\limits_{i=0}^k\frac{x^i - x^{k+1}}{1 - x} \\ &= \frac{\sum\limits_{i=0}^k x^i - (k+1)x^{k+1}}{1-x} \\ &= \frac{\frac{1-x^{k+1}}{1-x} - (k+1)x^{k+1}}{1 - x} = \frac{1 - (k + 2)x^{k+1} + (k + 1)x^{k + 2}}{(1-x)^2} \end{align*}\)
所以生成函数可以变形为\((1 - x^{k+1})^{a-b-1}(1-x)^{-(a - b + 1)}(1 - (k + 2)x^{k+1} + (k+1)x^{k+2})\)
注意到最后的一部分多项式只有\(3\)项,意味着前面两项的卷积只有\(x^b,x^{b - k - 1} , x^{b-k-2}\)项会对\(x^b\)项系数产生贡献
而由二项式定理可知
\((1 - x^{k+1})^{a-b-1} = \sum\limits_{i=0}^{a-b-1} \binom{a-b-1}{i} (-1)^i x^{ki+i},(1 - x)^{-(a-b+1)} = \sum\limits_{i=0}^{+\infty} \binom{-(a-b+1)}{i}(-1)^i x^i = \sum\limits_{i=0}^{+\infty} \binom{a - b + i}{i}x^i\)
故设\(A = \sum\limits_{ki+i+j = b} \binom{a-b-1}{i} (-1)^i \binom{a-b+j}{j} , B = \sum\limits_{ki+i+j = b - k - 1} \binom{a-b-1}{i} (-1)^i \binom{a-b+j}{j} , C = \sum\limits_{ki+i+j = b - k - 2} \binom{a-b-1}{i} (-1)^i \binom{a-b+j}{j}\)
那么\(cnt(a,b) = A - (k + 2)B + (k+1)C\)。\(ABC\)的计算式子都可以通过枚举\(i\)做到\(\frac{b}{k+1}\)的复杂度,所以计算\(cnt(a,b)\)的总复杂度为\(\frac{\sigma(n)}{k + 1}\),其中\(\sigma(n)\)为\(n\)的约数和,近似\(n\ log\ logn\)。
#include<iostream>
#include<cstdio>
#include<cstring>
//This code is written by Itst
using namespace std;
#define int long long
const int MAXN = 2e6 + 7 , MOD = 998244353;
int prm[MAXN] , jc[MAXN] , inv[MAXN] , mu[MAXN] , ans[MAXN];
int cnt , N , M , K;
bool nprm[MAXN];
inline int poww(int a , int b){
int times = 1;
while(b){
if(b & 1) times = times * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return times;
}
void init(){
mu[1] = 1;
for(int i = 2 ; i <= 1e6 ; ++i){
if(!nprm[i]){
prm[++cnt] = i;
mu[i] = -1;
}
for(int j = 1 ; j <= cnt && prm[j] * i <= 1e6 ; ++j){
nprm[prm[j] * i] = 0;
if(i % prm[j] == 0) break;
mu[i * prm[j]] = -1 * mu[i];
}
}
jc[0] = 1;
for(int i = 1 ; i <= 2e6 ; ++i)
jc[i] = jc[i - 1] * i % MOD;
inv[2000000] = poww(jc[2000000] , MOD - 2);
for(int i = 1999999 ; i >= 0 ; --i)
inv[i] = inv[i + 1] * (i + 1) % MOD;
}
int C(int b , int a){return b < a ? 0 : 1ll * jc[b] * inv[a] % MOD * inv[b - a] % MOD;}
int calc(int A , int B){
int sum1 = 0 , sum2 = 0 , sum3 = 0;
for(int i = 0 ; (K + 1) * i <= B ; ++i){
int j = B - (K + 1) * i;
sum1 = (sum1 + (i & 1 ? -1 : 1) * C(A - B - 1 , i) * C(A - B + j , j) % MOD + MOD) % MOD;
}
for(int i = 0 ; (K + 1) * i <= B - K - 1 ; ++i){
int j = B - K - 1 - (K + 1) * i;
sum2 = (sum2 + (i & 1 ? -1 : 1) * C(A - B - 1 , i) * C(A - B + j , j) % MOD + MOD) % MOD;
}
for(int i = 0 ; (K + 1) * i <= B - K - 2 ; ++i){
int j = B - K - 2 - (K + 1) * i;
sum3 = (sum3 + (i & 1 ? -1 : 1) * C(A - B - 1 , i) * C(A - B + j , j) % MOD + MOD) % MOD;
}
return (sum1 - (K + 2) * sum2 % MOD + (K + 1) * sum3 + MOD) % MOD;
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("gift.in","r",stdin);
freopen("gift.out","w",stdout);
#endif
init();
ios::sync_with_stdio(0);
cin >> N >> M >> K;
if(M == 0){puts("1"); return 0;}
for(int i = 1 ; i <= M ; ++i)
ans[N / i] = M % i == 0 && N % i == 0 ? calc(N / i , M / i) : 0;
for(int i = 1 ; i <= N ; ++i)
if(N % i == 0 && M % (N / i) == 0)
for(int j = 2 ; j * i <= N ; ++j)
ans[i * j] = (ans[i * j] + ans[i] * mu[j] + MOD) % MOD;
int sum = 0;
for(int i = 1 ; i <= N ; ++i)
if(N % i == 0 && M % (N / i) == 0)
sum = (sum + poww(i , MOD - 2) * ans[i]) % MOD;
cout << sum << '\n';
return 0;
}