\(n + 1\) 个横坐标互不相同的点可以唯一确定一个次数不超过 \(n\) 的多项式。
普通情况:
\(f(k) = \sum_{i = 1}^{n + 1} y_i \prod_{j \neq i} \frac{k - x_j}{x_i - x_j}\)
例题:https://www.luogu.com.cn/problem/P4781
给定多项式上的 \(n\) 个点 \((x_i, y_i)\) 以及一个数 \(k\),求出 \(f(k)\)。
当横坐标是连续整数:
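第 \(i\) 个点的坐标为 \((x_i, y_i) = (i, y_i)\),其中 \(i = 1, 2, \dots, n + 1\),此时 \(\prod_{j \neq i} (i - j) = (-1)^{n + 1 - i}(i - 1)!(n + 1 - i)!\):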
\(f(k) = \sum_{i = 1}^{n + 1} y_i \frac{\prod_{j \neq i} (k - j)}{(-1)^{n + 1 - i}(i - 1)!(n + 1 - i)!}\)
例题:https://codeforc.es/problemset/problem/622/F
给定 \(n\) 和 \(k\),求解 \(\sum_{i = 1}^{n} i^k\) 对 \(10^9 + 7\) 取模的值。由于该和是关于 \(n\) 的 \(k + 1\) 次多项式,只需取 \(k + 2\) 个点进行插值即可。
#include<bits/stdc++.h>
using namespace std;
using LL = long long;

/// Residue class modulo the compile-time constant `mod` (assumed > 0).
///
/// The public field `x` always holds the canonical representative in
/// [0, mod). Because `mod` is an `int`, the product of two residues stays
/// below 2^62, so every intermediate fits in a long long.
/// inv() and operator/ use Fermat's little theorem (a^(mod-2)); they are only
/// meaningful when `mod` is prime, and for mod > 2 the "inverse" of 0 is 0.
template<int mod>
struct ModZ{
    LL x;

    constexpr ModZ() : x() {}
    // Deliberately implicit so plain integers mix with ModZ in expressions.
    constexpr ModZ(LL x) : x(norm(x % mod)) {}

    /// Maps any long long (negatives included) into [0, mod).
    /// Static because it never reads *this.
    static constexpr LL norm(LL x) {return (x % mod + mod) % mod;}

    /// Returns a^b by binary exponentiation; expects b >= 0.
    /// Static because it never reads *this; `obj.power(a, b)` still compiles.
    static constexpr ModZ power(ModZ a, LL b) {
        ModZ res = 1;
        for (; b; b /= 2, a *= a) if (b & 1) res *= a;
        return res;
    }

    /// Multiplicative inverse computed as (*this)^(mod-2); see the primality note above.
    constexpr ModZ inv() const {return power(*this, mod - 2);}

    constexpr ModZ &operator *= (ModZ rhs) & {x = norm(x * rhs.x); return *this;}
    constexpr ModZ &operator += (ModZ rhs) & {x = norm(x + rhs.x); return *this;}
    constexpr ModZ &operator -= (ModZ rhs) & {x = norm(x - rhs.x); return *this;}
    constexpr ModZ &operator /= (ModZ rhs) & {return *this *= rhs.inv();}

    friend constexpr ModZ operator * (ModZ lhs, ModZ rhs) {ModZ res = lhs; res *= rhs; return res;}
    friend constexpr ModZ operator + (ModZ lhs, ModZ rhs) {ModZ res = lhs; res += rhs; return res;}
    friend constexpr ModZ operator - (ModZ lhs, ModZ rhs) {ModZ res = lhs; res -= rhs; return res;}
    friend constexpr ModZ operator / (ModZ lhs, ModZ rhs) {ModZ res = lhs; res /= rhs; return res;}

    // Stream operators are not constexpr: they call iostream code, which can
    // never run in a constant expression. `v` is initialized so the function is
    // well-formed and deterministic even if extraction fails.
    friend std::istream &operator >> (std::istream &is, ModZ &a) {LL v = 0; is >> v; a = ModZ(v); return is;}
    friend std::ostream &operator << (std::ostream &os, const ModZ &a) {return os << a.x;}

    friend constexpr bool operator == (ModZ lhs, ModZ rhs) {return lhs.x == rhs.x;}
    friend constexpr bool operator != (ModZ lhs, ModZ rhs) {return lhs.x != rhs.x;}
};
// Prime modulus 10^9 + 7 demanded by the problem; its primality is what lets
// ModZ::inv() compute inverses as a^(mod-2).
constexpr int mod = 1'000'000'007;
using Z = ModZ<mod>;
struct Lagrange{
int n;
vector<Z> x, y, fac, invfac;
Lagrange(int n){
this -> n = n;
x.resize(n + 3);
y.resize(n + 3);
fac.resize(n + 3);
invfac.resize(n + 3);
init(n);
}
void init(int n){
iota(x.begin(), x.end(), 0);
for (int i = 1; i <= n + 2; i ++ ){
Z t;
y[i] = y[i - 1] + t.power(i, n);
}
fac[0] = 1;
for (int i = 1; i <= n + 2; i ++ ){
fac[i] = fac[i - 1] * i;
}
invfac[n + 2] = fac[n + 2].inv();
for (int i = n + 1; i >= 0; i -- ){
invfac[i] = invfac[i + 1] * (i + 1);
}
}
Z solve(LL k){
if (k <= n + 2){
return y[k];
}
vector<Z> sub(n + 3);
for (int i = 1; i <= n + 2; i ++ ){
sub[i] = k - x[i];
}
vector<Z> mul(n + 3);
mul[0] = 1;
for (int i = 1; i <= n + 2; i ++ ){
mul[i] = mul[i - 1] * sub[i];
}
Z ans = 0;
for (int i = 1; i <= n + 2; i ++ ){
ans = ans + y[i] * mul[n + 2] * sub[i].inv() * pow(-1, n + 2 - i) * invfac[i - 1] * invfac[n + 2 - i];
}
return ans;
}
};
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int k, n;
cin >> k >> n;
Lagrange LI(n);
cout << LI.solve(k) << "\n";
return 0;
}