P4916 [MtOI2018]魔力环

tag:组合计数,burnside


枚举所有旋转 \((x\to x+i)\),等价类一共有 \(\gcd(n,i)\) 个,每个等价类大小为 \(\frac n{\gcd(n,i)}\)

然后问题变为,有一个长度为 \(n\) 的环,涂黑 \(m\) 个球,不能连续涂黑超过 \(k\) 个球,求方案数。

可以枚举第一个和最后一个白球之间有 \(i\) 个黑球,然后剩下有 \(m-i\) 个黑球,\(n-m\) 个白球,可以看作是将 \(m-i\) 个黑球塞到 \(n-m-1\) 个间隔中,每个间隔不能超过 \(k\) 个黑球。

这个问题可以用容斥解决,枚举强制让 \(i\) 个间隔超过 \(k\) 个,系数就是 \((-1)^i\)

所以得到

\[ans=\sum_{\frac n{\gcd(n,i)}|m}F(\gcd(n,i),\frac m{\frac n{\gcd(n,i)}},k) \]

\[F(n,m,k)=\sum_{i=0}^k(i+1)G(m-i,n-m-1,k) \]

\[G(n,m,k)=\sum_{i=0}^{\min(\lfloor\frac n{k+1}\rfloor,m)}\binom miH(n-i(k+1),m) \]

\[H(n,m)=\binom{n+m-1}n \]

注意计算 \(G\) 的时候考虑 \(m\leq k\) 的情况。

一次 \(F(n,m,k)\) 的复杂度是 \(O(k)\cdot O(\frac mk)=O(m)\)

总共复杂度就是 \(m\) 的约数和,松上界 \(O(mlogm)\)


hehezhou的做法:

\[[x^{n-1}](\sum_{i=0}^k(i+1)x^{i})(\sum_{i=0}^kx^{i+1})^{n-m} \]

好像化简以后差不多
原话:

只求那一项的系数可以手动展开


#include<bits/stdc++.h>
using namespace std;

template<typename T>
inline void Read(T &n){
	char ch; bool flag=false;
	while(!isdigit(ch=getchar()))if(ch=='-')flag=true;
	for(n=ch^48;isdigit(ch=getchar());n=(n<<1)+(n<<3)+(ch^48));
	if(flag)n=-n;
}

enum{
	MAXN = 200005,
	MOD = 998244353
};

inline int ksm(int base, int k=MOD-2){
	int res=1;
	while(k){
		if(k&1)
			res = 1ll*res*base%MOD;
		base = 1ll*base*base%MOD;
		k >>= 1;
	}
	return res;
}

inline int inc(int a, int b){
	a += b;
	if(a>=MOD) a -= MOD;
	return a;
}

inline int dec(int a, int b){
	a -= b;
	if(a<0) a += MOD;
	return a;
}

inline void iinc(int &a, int b){a = inc(a,b);}
inline void ddec(int &a, int b){a = dec(a,b);}
inline void upd(int &a, long long b){a = (a+b)%MOD;}

int jc[MAXN], invjc[MAXN];
inline int C(int n, int m){return 1ll*jc[n]*invjc[m]%MOD*invjc[n-m]%MOD;}
inline void prework(int n){
	jc[0] = 1; for(int i=1; i<=n; i++) jc[i] = 1ll*jc[i-1]*i%MOD;
	invjc[n] = ksm(jc[n]); for(int i=n; i; i--) invjc[i-1] = 1ll*i*invjc[i]%MOD;
}

inline int calc1(int n, int m){return C(n+m-1,n);}

inline int calc2(int n, int m, int k){
	int ans=0;
	for(int i=0; i<=m and i*(k+1)<=n; i++)
		if(i&1) ddec(ans,1ll*C(m,i)*calc1(n-i*(k+1),m)%MOD);
		else upd(ans,1ll*C(m,i)*calc1(n-i*(k+1),m));
	return ans;
}

inline int calc3(int n, int m, int k){
	if(m<=k) return C(n,m);
	int ans=0;
	for(int i=0; i<=k; i++) upd(ans,1ll*(i+1)*calc2(m-i,n-m-1,k));
	return ans;
}

int ans[MAXN];

int gcd(int a, int b){return b?gcd(b,a%b):a;}

int main(){
	int n, m, k;
	Read(n); Read(m); Read(k);
	prework(MAXN-1);
	memset(ans,-1,sizeof ans);
	ans[0] = 0;
	for(int i=1; i<=n; i++){
		int x = n/gcd(n,i);
		if(m%x) continue;
		if(ans[x]==-1) ans[x] = calc3(n/x,m/x,k);
		iinc(ans[0],ans[x]);
	}
	cout<<1ll*ans[0]*ksm(n)%MOD<<'\n';
	return 0;
}
posted @ 2021-07-01 15:04  oisdoaiu  阅读(53)  评论(0编辑  收藏  举报