CF961G Partitions
https://www.luogu.com.cn/problem/CF961G
这nm居然是Div2G??
斯特林数及其反演好题
考场上能推出来难度不低啊
前置知识:斯特林数及其反演
首先考虑把答案写成每个物品的贡献求和的形式
A
N
S
=
∑
i
=
1
n
w
i
∑
j
=
1
n
j
C
n
−
1
j
−
1
S
(
n
−
j
,
k
−
1
)
ANS = \sum\limits_{i=1}^n w_i \sum\limits_{j=1}^n j C_{n-1}^{j-1}S(n-j,k-1)
ANS=i=1∑nwij=1∑njCn−1j−1S(n−j,k−1)
i
i
i表示哪个物品,
j
j
j表示这个集合大小
发现i,j无关,可以分开计算再乘起来
∑
j
=
1
n
j
C
n
−
1
j
−
1
S
(
n
−
j
,
k
−
1
)
\large \sum\limits_{j=1}^n j C_{n-1}^{j-1}S(n-j,k-1)
j=1∑njCn−1j−1S(n−j,k−1)
熟练的同学可以发现如果没有
j
j
j就是
S
(
n
,
k
)
S(n,k)
S(n,k),所以先分离出这一项出来
=
∑
j
=
1
n
C
n
−
1
j
−
1
S
(
n
−
j
,
k
−
1
)
+
∑
j
=
1
n
C
n
−
1
j
−
1
S
(
n
−
j
,
k
−
1
)
(
j
−
1
)
\large= \sum\limits_{j=1}^n C_{n-1}^{j-1}S(n-j,k-1) + \sum\limits_{j=1}^n C_{n-1}^{j-1}S(n-j,k-1)(j-1)
=j=1∑nCn−1j−1S(n−j,k−1)+j=1∑nCn−1j−1S(n−j,k−1)(j−1)
考虑如何将后面那项变成前面那种形式,发现主要是(j-1)比较难搞,考虑把组合数拆开
发现把(j-1)约掉以后再提一个(n-1)出来还是组合数
=
∑
j
=
1
n
C
n
−
1
j
−
1
S
(
n
−
j
,
k
−
1
)
+
(
n
−
1
)
∑
j
=
1
n
C
n
−
2
j
−
2
S
(
n
−
j
,
k
−
1
)
\large= \sum\limits_{j=1}^n C_{n-1}^{j-1}S(n-j,k-1) + (n-1)\sum\limits_{j=1}^n C_{n-2}^{j-2}S(n-j,k-1)
=j=1∑nCn−1j−1S(n−j,k−1)+(n−1)j=1∑nCn−2j−2S(n−j,k−1)
现在右边也变成一样的形式了,所以答案就是
S
(
n
,
k
)
+
(
n
−
1
)
S
(
n
−
1
,
k
)
\large S(n, k)+(n-1)S(n-1,k)
S(n,k)+(n−1)S(n−1,k)
直接算即可
code:
#include<bits/stdc++.h>
#define ll long long
#define mod 1000000007
#define N 200050
using namespace std;
ll qpow(ll x, ll y) {
ll ret = 1;
for(; y; y >>= 1, x = x * x % mod) if(y & 1) ret = ret * x % mod;
return ret;
}
ll fac[N], ifac[N];
ll C(int n, int m) {
return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
void init(int n) {
fac[0] = 1;
for(int i = 1; i <= n; i ++) fac[i] = fac[i - 1] * i % mod;
ifac[n] = qpow(fac[n], mod - 2);
for(int i = n - 1; i >= 0; i --) ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
ll S(int n, int m) {
ll ret = 0;
for(int i = 0; i <= m; i ++) ret += qpow(mod - 1, m - i) % mod * C(m, i) % mod * qpow(i, n) % mod, ret %= mod;
ret = ret * ifac[m] % mod;
// printf("%lld\n", ret);
return ret;
}
int n, k; ll w[N];
int main() {
scanf("%d%d", &n, &k);
init(n);
// printf("%lld", C(3, 2));
ll s = 0;
for(int i = 1; i <= n; i ++) scanf("%lld", &w[i]), s += w[i], s %= mod;
printf("%lld", s * (S(n, k) + (n - 1) * S(n - 1, k) % mod + mod) % mod);
return 0;
}