【YBT2022寒假Day6 A】【luogu CF891E】随机减法 / Lust(EGF)
随机减法 / Lust
题目链接:YBT2022寒假Day6 A / luogu CF891E
题目大意
给你一个数组,每次随机选一个数减一,然后贡献增加除了这个数以外所有数的乘积,然后问你操作 k 次之后期望的贡献和。
思路
它这个除了以外某个数以外的乘积很不好搞,我们考虑一定把它弄成跟全部乘积有关的。
然后你发现每次减一,那全部乘积就减少了除了这个数以外所有数的乘积。
那每次的贡献就是全部乘积的减小量,那总贡献就是全部乘积总共减少的量,也就是: \(\prod a_i-\prod(a_i-b_i)\)(\(b_i\) 是 \(k\) 次操作选了 \(a_i\) 多少次)
然后 \(b_i\) 是每个的次数,你还要放的顺序,所以要乘上 \(\dfrac{k!}{\prod b_i!}\)
然后因为是期望,所以要乘概率:\(\dfrac{1}{n^k}\)
所以我们要得到的就是:\(\prod a_i-\sum\limits_{b}\dfrac{1}{n^k}\dfrac{k!}{\prod b_i!}\prod(a_i-b_i)\)
然后把 \(b\) 的放在一起:\(\dfrac{k!}{n^k}\prod\dfrac{a_i-b_i}{b_i!}\)
你发现右边这个 \(\prod\) 里面的好像是个 EGF 的性质,考虑试试:
\(f(i)=\sum\limits_{j=0}\dfrac{a_i-j}{j!}x^j\)
\(f(i)=\sum\limits_{j=0}(\dfrac{a_i}{j!}x^j-\dfrac{x}{(j-1)!}x^{j-1})\)
\(f(i)=\sum\limits_{j=0}\dfrac{a_i-x}{j!}x^j=(a_i-x)e^x\)
然后你是要每一项乘起来:
\(F(x)=\prod\limits_{i=1}^nf(i)=e^{nx}\prod\limits_{i=1}^n(a_i-x)\)
\(F(x)=\sum\limits_{i=0}n^i\dfrac{x^i}{i!}\prod\limits_{i=1}^n(a_i-x)\)
那么可以看出 \(\prod\limits_{i=1}^n(a_i-x)\) 是一个多项式,而且暴力算的复杂度是 \(O(n^2)\),是可以的。
那么假设我们得出来的结果是:\(\sum\limits c_ix^i\)
\(F(x)=\sum\limits_{i=0}n^i\dfrac{x^i}{i!}\sum\limits_{i=0}c_ix^i\)
那我们答案要的自然是第 \(k\) 项:
\([x^k]F(x)=\sum\limits_{i=0}^k n^{k-i}\dfrac{x^{k-i}}{(k-i)!}c_{i}x^{i}=\sum\limits_{i=0}^k\dfrac{n^{k-i}c_i}{(k-i)!}x^k\)
然后我们带回去:
\(E=\dfrac{k!}{n^k}*[x^k]F(x)\)
\(=\sum\limits_{i=0}^k\dfrac{k!}{n^k}\dfrac{n^{k-i}c_i}{(k-i)!}x^k\)
\(=\sum\limits_{i=0}^k\dfrac{k!c_i}{n^i(k-i)!}\)
\(=\sum\limits_{i=0}^k\dfrac{(k-i+1)(k-i+2)(...)(k)c_i}{n^i}\)
(不能直接算阶乘,所以要拆开来抵消,每次在原来基础上乘新的数)
然后最后 \(\prod a_i-E\) 即可。
代码
#include<cstdio>
#define ll long long
#define mo 1000000007
using namespace std;
ll n, k, a[5001], sum, c[5001], cc[5001];
void get_C() {
c[0] = 1;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++)
cc[j] = c[j - 1];
for (int j = 0; j <= n; j++)
c[j] = (c[j] * a[i] % mo - cc[j] + mo) % mo;
}
}
ll ksm(ll x, ll y) {
ll re = 1;
while (y) {
if (y & 1) re = re * x % mo;
x = x * x % mo;
y >>= 1;
}
return re;
}
int main() {
// freopen("calculate.in", "r", stdin);
// freopen("calculate.out", "w", stdout);
scanf("%lld %lld", &n, &k); sum = 1;
for (int i = 1; i <= n; i++) {
scanf("%lld", &a[i]); sum = sum * a[i] % mo;
}
get_C();
ll xnfn = 0, tmp = 1, invn = ksm(n, mo - 2);
for (int i = 0; i <= n; i++) {
(xnfn += c[i] * tmp % mo) %= mo;
(tmp *= invn * (k - i) % mo) %= mo;
}
printf("%lld", (sum - xnfn + mo) % mo);
return 0;
}