CF961G Partitions

https://www.luogu.com.cn/problem/CF961G

这nm居然是Div2G??
斯特林数及其反演好题
考场上能推出来难度不低啊

前置知识:斯特林数及其反演

首先考虑把答案写成每个物品的贡献求和的形式
A N S = ∑ i = 1 n w i ∑ j = 1 n j C n − 1 j − 1 S ( n − j , k − 1 ) ANS = \sum\limits_{i=1}^n w_i \sum\limits_{j=1}^n j C_{n-1}^{j-1}S(n-j,k-1) ANS=i=1nwij=1njCn1j1S(nj,k1)
i i i表示哪个物品, j j j表示这个集合大小

发现i,j无关,可以分开计算再乘起来
∑ j = 1 n j C n − 1 j − 1 S ( n − j , k − 1 ) \large \sum\limits_{j=1}^n j C_{n-1}^{j-1}S(n-j,k-1) j=1njCn1j1S(nj,k1)
熟练的同学可以发现如果没有 j j j就是 S ( n , k ) S(n,k) S(n,k),所以先分离出这一项出来

= ∑ j = 1 n C n − 1 j − 1 S ( n − j , k − 1 ) + ∑ j = 1 n C n − 1 j − 1 S ( n − j , k − 1 ) ( j − 1 ) \large= \sum\limits_{j=1}^n C_{n-1}^{j-1}S(n-j,k-1) + \sum\limits_{j=1}^n C_{n-1}^{j-1}S(n-j,k-1)(j-1) =j=1nCn1j1S(nj,k1)+j=1nCn1j1S(nj,k1)(j1)
考虑如何将后面那项变成前面那种形式,发现主要是(j-1)比较难搞,考虑把组合数拆开
发现把(j-1)约掉以后再提一个(n-1)出来还是组合数
= ∑ j = 1 n C n − 1 j − 1 S ( n − j , k − 1 ) + ( n − 1 ) ∑ j = 1 n C n − 2 j − 2 S ( n − j , k − 1 ) \large= \sum\limits_{j=1}^n C_{n-1}^{j-1}S(n-j,k-1) + (n-1)\sum\limits_{j=1}^n C_{n-2}^{j-2}S(n-j,k-1) =j=1nCn1j1S(nj,k1)+(n1)j=1nCn2j2S(nj,k1)
现在右边也变成一样的形式了,所以答案就是
S ( n , k ) + ( n − 1 ) S ( n − 1 , k ) \large S(n, k)+(n-1)S(n-1,k) S(n,k)+(n1)S(n1,k)

直接算即可
code:

#include<bits/stdc++.h>
#define ll long long
#define mod 1000000007
#define N 200050
using namespace std;
ll qpow(ll x, ll y) {
    ll ret = 1;
    for(; y; y >>= 1, x = x * x % mod) if(y & 1) ret = ret * x % mod;
    return ret;
}
ll fac[N], ifac[N];
ll C(int n, int m) {
    return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
void init(int n) {
    fac[0] = 1;
    for(int i = 1; i <= n; i ++) fac[i] = fac[i - 1] * i % mod;
    ifac[n] = qpow(fac[n], mod - 2);
    for(int i = n - 1; i >= 0; i --) ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
ll S(int n, int m) {
    ll ret = 0;
    for(int i = 0; i <= m; i ++) ret += qpow(mod - 1, m - i) % mod * C(m, i) % mod * qpow(i, n) % mod, ret %= mod;
    ret = ret * ifac[m] % mod;
//    printf("%lld\n", ret);
    return ret;
}
int n, k; ll w[N];
int main() {
    scanf("%d%d", &n, &k);
    init(n);
   // printf("%lld", C(3, 2));
    ll s = 0;
    for(int i = 1; i <= n; i ++) scanf("%lld", &w[i]), s += w[i], s %= mod;
    printf("%lld", s * (S(n, k) + (n - 1) * S(n - 1, k) % mod + mod) % mod);
    return 0;
}
posted @ 2021-07-30 21:41  lahlah  阅读(104)  评论(0编辑  收藏  举报