P4916 [MtOI2018]魔力环
tag:组合计数,burnside
枚举所有旋转 \((x\to x+i)\),等价类一共有 \(\gcd(n,i)\) 个,每个等价类大小为 \(\frac n{\gcd(n,i)}\)。
然后问题变为,有一个长度为 \(n\) 的环,涂黑 \(m\) 个球,不能连续涂黑超过 \(k\) 个球,求方案数。
可以枚举第一个和最后一个白球之间有 \(i\) 个黑球,然后剩下有 \(m-i\) 个黑球,\(n-m\) 个白球,可以看作是将 \(m-i\) 个黑球塞到 \(n-m-1\) 个间隔中,每个间隔不能超过 \(k\) 个黑球。
这个问题可以用容斥解决,枚举强制让 \(i\) 个间隔超过 \(k\) 个,系数就是 \((-1)^i\)
所以得到
\[ans=\sum_{\frac n{\gcd(n,i)}|m}F(\gcd(n,i),\frac m{\frac n{\gcd(n,i)}},k)
\]
\[F(n,m,k)=\sum_{i=0}^k(i+1)G(m-i,n-m-1,k)
\]
\[G(n,m,k)=\sum_{i=0}^{\min(\lfloor\frac n{k+1}\rfloor,m)}\binom miH(n-i(k+1),m)
\]
\[H(n,m)=\binom{n+m-1}n
\]
注意计算 \(G\) 的时候考虑 \(m\leq k\) 的情况。
一次 \(F(n,m,k)\) 的复杂度是 \(O(k)\cdot O(\frac mk)=O(m)\)
总共复杂度就是 \(m\) 的约数和,松上界 \(O(mlogm)\)。
hehezhou的做法:
\[[x^{n-1}](\sum_{i=0}^k(i+1)x^{i})(\sum_{i=0}^kx^{i+1})^{n-m}
\]
好像化简以后差不多
原话:
只求那一项的系数可以手动展开
#include<bits/stdc++.h>
using namespace std;
template<typename T>
inline void Read(T &n){
char ch; bool flag=false;
while(!isdigit(ch=getchar()))if(ch=='-')flag=true;
for(n=ch^48;isdigit(ch=getchar());n=(n<<1)+(n<<3)+(ch^48));
if(flag)n=-n;
}
enum{
MAXN = 200005,
MOD = 998244353
};
inline int ksm(int base, int k=MOD-2){
int res=1;
while(k){
if(k&1)
res = 1ll*res*base%MOD;
base = 1ll*base*base%MOD;
k >>= 1;
}
return res;
}
inline int inc(int a, int b){
a += b;
if(a>=MOD) a -= MOD;
return a;
}
inline int dec(int a, int b){
a -= b;
if(a<0) a += MOD;
return a;
}
inline void iinc(int &a, int b){a = inc(a,b);}
inline void ddec(int &a, int b){a = dec(a,b);}
inline void upd(int &a, long long b){a = (a+b)%MOD;}
int jc[MAXN], invjc[MAXN];
inline int C(int n, int m){return 1ll*jc[n]*invjc[m]%MOD*invjc[n-m]%MOD;}
inline void prework(int n){
jc[0] = 1; for(int i=1; i<=n; i++) jc[i] = 1ll*jc[i-1]*i%MOD;
invjc[n] = ksm(jc[n]); for(int i=n; i; i--) invjc[i-1] = 1ll*i*invjc[i]%MOD;
}
inline int calc1(int n, int m){return C(n+m-1,n);}
inline int calc2(int n, int m, int k){
int ans=0;
for(int i=0; i<=m and i*(k+1)<=n; i++)
if(i&1) ddec(ans,1ll*C(m,i)*calc1(n-i*(k+1),m)%MOD);
else upd(ans,1ll*C(m,i)*calc1(n-i*(k+1),m));
return ans;
}
inline int calc3(int n, int m, int k){
if(m<=k) return C(n,m);
int ans=0;
for(int i=0; i<=k; i++) upd(ans,1ll*(i+1)*calc2(m-i,n-m-1,k));
return ans;
}
int ans[MAXN];
int gcd(int a, int b){return b?gcd(b,a%b):a;}
int main(){
int n, m, k;
Read(n); Read(m); Read(k);
prework(MAXN-1);
memset(ans,-1,sizeof ans);
ans[0] = 0;
for(int i=1; i<=n; i++){
int x = n/gcd(n,i);
if(m%x) continue;
if(ans[x]==-1) ans[x] = calc3(n/x,m/x,k);
iinc(ans[0],ans[x]);
}
cout<<1ll*ans[0]*ksm(n)%MOD<<'\n';
return 0;
}