题解 莓良心

传送门

我我我不会斯特林数……导致考场上用了个 \(n^3\) 的DP推斯特林数并T飞

  • 关于斯特林数的计算公式:
    一个是:

    \[\begin{Bmatrix}n\\k\end{Bmatrix}=\frac{1}{k!}\sum\limits_{i=0}^k(-1)^i\binom{n}{i}(k-i)^n \]

    线性筛处理 \((k-i)^n\) 的话可以做到 \(O(k)\)
    它在做的事情大概是给每个位置赋一个 \([1, k]\) 间的标号,再容斥掉标号不满 \(k\) 种的
    因为最终的集合没有顺序,所以最后要乘上 \(\frac{1}{k!}\)

    还有一个式子是:

    \[\begin{Bmatrix}n\\m\end{Bmatrix}=\sum\limits_{i=0}^m\dfrac{(-1)^{m-i}i^n}{i!(m-i)!} \]

    来源这里,先咕

然后考场思路
每个组的贡献是 \(siz*\sum w_i\)
可以写成 \(\sum w_i*siz\)
于是枚举每个 \(w_i\) 所在的组的大小,乘上将剩下的数再分组的方案数
可惜需要计算一列斯特林数而不好做

正解又是个神仙思路

  • 当需要计算形如权值×所在组的大小的问题时,特别注意一个事情:
    这个贡献可以拆成点对的贡献,即在同一组内的一对点 \((u, v)\) 的贡献是 \(w_u+w_v\)
    特别地,\((u, u)\) 的贡献为 \(w_u\)

于是对于这个题,答案就是

\[\begin{Bmatrix}n\\k\end{Bmatrix}\sum w_i + \sum\limits_{u=1}^n\sum\limits_{v=u+1}^n(w_u+w_v)\begin{Bmatrix}n-1\\k\end{Bmatrix} \]

前一部分是考虑 \((u, u)\) 的贡献
后一部分是在枚举点对,要钦定 \(u, v\) 在同一组,所以是 \(n-1\)
于是式子可以展开化成

\[(\begin{Bmatrix}n\\k\end{Bmatrix}+(n-1)\begin{Bmatrix}n-1\\k\end{Bmatrix})\sum w_i \]

于是可以 \(O(n)\) 算出来需要的两个斯特林数

Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 5000010
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, k;
ll w[N], fac[N], inv[N];
const ll mod=998244353;
inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

namespace force{
	ll dp[2010][2010], met, sum;
	void solve() {
		dp[0][0]=1;
		for (int j=1; j<=k; ++j) {
			for (int i=j; i<=n; ++i) {
				for (int t=1; t<=i; ++t) {
					dp[i][j]=(dp[i][j]+C(i-1, t-1)*dp[i-t][j-1]%mod)%mod;
					// cout<<"t: "<<t<<' '<<C(i, t)<<' '<<dp[i-t][j-1]<<endl;
				}
				printf("dp[%d][%d]=%lld\n", i, j, dp[i][j]);
			}
		}
		for (int i=1; i<=n; ++i) met=(met+i*C(n-1, i-1)%mod*dp[n-i][k-1]%mod)%mod;
		// cout<<"met: "<<met<<endl;
		for (int i=1; i<=n; ++i) sum=(sum+w[i])%mod;
		printf("%lld\n", sum*met%mod);
		exit(0);
	}
}

namespace task1{
	ll s[2010][2010], met, sum;
	void solve() {
		s[0][0]=1;
		for (int i=1; i<=n; ++i) {
			for (int j=1; j<=n; ++j) {
				s[i][j]=(s[i-1][j-1]+j*s[i-1][j])%mod;
				// printf("s[%d][%d]=%lld\n", i, j, s[i][j]);
			}
		}
		for (int i=1; i<=n; ++i) met=(met+i*C(n-1, i-1)%mod*s[n-i][k-1]%mod)%mod;
		// cout<<"met: "<<met<<endl;
		for (int i=1; i<=n; ++i) sum=(sum+w[i])%mod;
		printf("%lld\n", sum*met%mod);
		exit(0);
	}
}

namespace task{
	int pri[N], pcnt;
	ll sum, qp1[N], qp2[N];
	bool npri[N];
	ll s(int n, int k, ll* qp) {
		ll ans=0;
		for (int i=0; i<=n; ++i) {
			ans=(ans+(i&1?-1:1)*C(k, i)%mod*qp[k-i])%mod;
		}
		return ans*inv[k]%mod;
	}
	void solve() {
		qp1[1]=qp2[1]=1;
		for (int i=2; i<N; ++i) {
			if (!npri[i]) pri[++pcnt]=i, qp1[i]=qpow(i, n), qp2[i]=qpow(i, n-1);
			for (int j=1; j<=pcnt&&1ll*i*pri[j]<N; ++j) {
				npri[i*pri[j]]=1;
				qp1[i*pri[j]]=qp1[i]*qp1[pri[j]]%mod;
				qp2[i*pri[j]]=qp2[i]*qp2[pri[j]]%mod;
				if (!(i%pri[j])) break;
			}
		}
		for (int i=1; i<=n; ++i) sum=(sum+w[i])%mod;
		printf("%lld\n", (sum*(s(n, k, qp1)+1ll*(n-1)*s(n-1, k, qp2)%mod)%mod+mod)%mod);
		exit(0);
	}
}

signed main()
{
	freopen("ichigo.in", "r", stdin);
	freopen("ichigo.out", "w", stdout);

	n=read(); k=read();
	for (int i=1; i<=n; ++i) w[i]=read();
	fac[0]=fac[1]=1; inv[0]=inv[1]=1;
	for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
	for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
	// force::solve();
	task::solve();

	return 0;
}
posted @ 2021-10-29 21:36  Administrator-09  阅读(4)  评论(0编辑  收藏  举报