题解 莓良心
我我我不会斯特林数……导致考场上用了个 \(n^3\) 的DP推斯特林数并T飞
-
关于斯特林数的计算公式:
一个是:\[\begin{Bmatrix}n\\k\end{Bmatrix}=\frac{1}{k!}\sum\limits_{i=0}^k(-1)^i\binom{n}{i}(k-i)^n \]线性筛处理 \((k-i)^n\) 的话可以做到 \(O(k)\)
它在做的事情大概是给每个位置赋一个 \([1, k]\) 间的标号,再容斥掉标号不满 \(k\) 种的
因为最终的集合没有顺序,所以最后要乘上 \(\frac{1}{k!}\)还有一个式子是:
\[\begin{Bmatrix}n\\m\end{Bmatrix}=\sum\limits_{i=0}^m\dfrac{(-1)^{m-i}i^n}{i!(m-i)!} \]来源这里,先咕
然后考场思路
每个组的贡献是 \(siz*\sum w_i\)
可以写成 \(\sum w_i*siz\)
于是枚举每个 \(w_i\) 所在的组的大小,乘上将剩下的数再分组的方案数
可惜需要计算一列斯特林数而不好做
正解又是个神仙思路
- 当需要计算形如权值×所在组的大小的问题时,特别注意一个事情:
这个贡献可以拆成点对的贡献,即在同一组内的一对点 \((u, v)\) 的贡献是 \(w_u+w_v\)
特别地,\((u, u)\) 的贡献为 \(w_u\)
于是对于这个题,答案就是
\[\begin{Bmatrix}n\\k\end{Bmatrix}\sum w_i + \sum\limits_{u=1}^n\sum\limits_{v=u+1}^n(w_u+w_v)\begin{Bmatrix}n-1\\k\end{Bmatrix}
\]
前一部分是考虑 \((u, u)\) 的贡献
后一部分是在枚举点对,要钦定 \(u, v\) 在同一组,所以是 \(n-1\)
于是式子可以展开化成
\[(\begin{Bmatrix}n\\k\end{Bmatrix}+(n-1)\begin{Bmatrix}n-1\\k\end{Bmatrix})\sum w_i
\]
于是可以 \(O(n)\) 算出来需要的两个斯特林数
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 5000010
#define ll long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n, k;
ll w[N], fac[N], inv[N];
const ll mod=998244353;
inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
namespace force{
ll dp[2010][2010], met, sum;
void solve() {
dp[0][0]=1;
for (int j=1; j<=k; ++j) {
for (int i=j; i<=n; ++i) {
for (int t=1; t<=i; ++t) {
dp[i][j]=(dp[i][j]+C(i-1, t-1)*dp[i-t][j-1]%mod)%mod;
// cout<<"t: "<<t<<' '<<C(i, t)<<' '<<dp[i-t][j-1]<<endl;
}
printf("dp[%d][%d]=%lld\n", i, j, dp[i][j]);
}
}
for (int i=1; i<=n; ++i) met=(met+i*C(n-1, i-1)%mod*dp[n-i][k-1]%mod)%mod;
// cout<<"met: "<<met<<endl;
for (int i=1; i<=n; ++i) sum=(sum+w[i])%mod;
printf("%lld\n", sum*met%mod);
exit(0);
}
}
namespace task1{
ll s[2010][2010], met, sum;
void solve() {
s[0][0]=1;
for (int i=1; i<=n; ++i) {
for (int j=1; j<=n; ++j) {
s[i][j]=(s[i-1][j-1]+j*s[i-1][j])%mod;
// printf("s[%d][%d]=%lld\n", i, j, s[i][j]);
}
}
for (int i=1; i<=n; ++i) met=(met+i*C(n-1, i-1)%mod*s[n-i][k-1]%mod)%mod;
// cout<<"met: "<<met<<endl;
for (int i=1; i<=n; ++i) sum=(sum+w[i])%mod;
printf("%lld\n", sum*met%mod);
exit(0);
}
}
namespace task{
int pri[N], pcnt;
ll sum, qp1[N], qp2[N];
bool npri[N];
ll s(int n, int k, ll* qp) {
ll ans=0;
for (int i=0; i<=n; ++i) {
ans=(ans+(i&1?-1:1)*C(k, i)%mod*qp[k-i])%mod;
}
return ans*inv[k]%mod;
}
void solve() {
qp1[1]=qp2[1]=1;
for (int i=2; i<N; ++i) {
if (!npri[i]) pri[++pcnt]=i, qp1[i]=qpow(i, n), qp2[i]=qpow(i, n-1);
for (int j=1; j<=pcnt&&1ll*i*pri[j]<N; ++j) {
npri[i*pri[j]]=1;
qp1[i*pri[j]]=qp1[i]*qp1[pri[j]]%mod;
qp2[i*pri[j]]=qp2[i]*qp2[pri[j]]%mod;
if (!(i%pri[j])) break;
}
}
for (int i=1; i<=n; ++i) sum=(sum+w[i])%mod;
printf("%lld\n", (sum*(s(n, k, qp1)+1ll*(n-1)*s(n-1, k, qp2)%mod)%mod+mod)%mod);
exit(0);
}
}
signed main()
{
freopen("ichigo.in", "r", stdin);
freopen("ichigo.out", "w", stdout);
n=read(); k=read();
for (int i=1; i<=n; ++i) w[i]=read();
fac[0]=fac[1]=1; inv[0]=inv[1]=1;
for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
// force::solve();
task::solve();
return 0;
}