【题解】P4593 [TJOI2018] 教科书般的亵渎
之前整理的时候忘记写,现在补上。
思路
拉插求自然数幂和。
关于自然数幂和 \(\sum\limits_{i = 1}^n i^k\),已知是关于 \(n\) 的 \(k + 1\) 次多项式,可以用伯努利数 \(O(k \log k)\) 求,也可以直接拉插 \(O(k)\) 求。
拉插结论:若一个 \(n\) 次多项式 \(f\) 经过 \(n + 1\) 个点,则 \(f(x) = \sum\limits_{i = 1}^{n + 1} y_i \prod\limits_{j \neq i} \frac{x - x_j}{x_i - x_j}.\)
于是只需要求出 \(f = \sum\limits_{i = 1}^n i^k\) 在 \(n\) 取 \(1\) 到 \(k + 2\) 时的点值,就可以拉插计算得到答案。
回到这题,考虑到每次「亵渎」作用于全局,可以猜到最终使用「亵渎」的次数是一定的,观察得到使用次数 \(k = m + 1\).
假设没有不存在的血量,最终的分数是 \(\sum\limits_{i = 1}^n \sum\limits_{j = 1}^{n - i + 1} j^k\).
不存在的血量意味着需要在此处多使用一次亵渎,并且这里无法贡献分数,在上面的基础上容斥。方便起见,认为在 \(0\) 处有一头不存在的怪兽,编号为 \(a_0\)。最终答案是:
\(\sum\limits_{i = 0}^m \sum\limits_{j = 1}^{n - a_i} j^k - \sum\limits_{j = i + 1} (a_j - a_i)^k\).
朴素的复杂度是 \(O(m^2 \log |V| + mk)\).
代码
#include <cstdio>
#include <algorithm>
using namespace std;
#define int long long
const int maxn = 3.5e6 + 5;
const int maxm = 50 + 5;
const int mod = 1e9 + 7;
int t, n, m, k;
int a[maxm];
int pre[maxn], suf[maxn], fac[maxn];
int qpow(int base, int power)
{
int res = 1;
while (power)
{
if (power & 1) res = 1ll * res * base % mod;
base = 1ll * base * base % mod, power >>= 1;
}
return res;
}
int solve(int n, int k)
{
if (n <= 0) return 0;
int y = 0, res = 0;
pre[0] = fac[0] = suf[k + 3] = 1;
for (int i = 1; i <= k + 2; i++) pre[i] = 1ll * pre[i - 1] * (n - i) % mod;
for (int i = k + 2; i >= 1; i--) suf[i] = 1ll * suf[i + 1] * (n - i) % mod;
for (int i = 1; i <= k + 2; i++) fac[i] = 1ll * fac[i - 1] * i % mod;
for (int i = 1; i <= k + 2; i++)
{
y = (y + qpow(i, k)) % mod;
int a = 1ll * pre[i - 1] * suf[i + 1] % mod;
int b = fac[i - 1] * ((k - i) & 1 ? -1ll : 1ll) * fac[k + 2 - i] % mod;
res = (res + 1ll * y * a % mod * qpow(b, mod - 2) % mod) % mod;
}
res = (res % mod + mod) % mod;
return res;
}
signed main()
{
scanf("%lld", &t);
a[0] = 0;
while (t--)
{
scanf("%lld%lld", &n, &m), k = m + 1;
for (int i = 1; i <= m; i++) scanf("%lld", &a[i]);
sort(a + 1, a + m + 1);
int ans = 0;
for (int i = 0; i <= m; i++)
{
ans = (ans + solve(n - a[i], k)) % mod;
for (int j = i + 1; j <= m; j++) ans = ((ans - qpow(a[j] - a[i], k)) % mod + mod) % mod;
}
printf("%lld\n", ans);
}
return 0;
}