横坐标互不相同的 \(n + 1\) 个点可以唯一确定一个次数不超过 \(n\) 的多项式。

普通情况:
\(f(k) = \sum_{i = 1}^{n + 1} y_i \prod_{j \neq i} \frac{k - x_j}{x_i - x_j}\)

例题:https://www.luogu.com.cn/problem/P4781
给定多项式上的 \(n\) 个点以及 \(k\),求 \(f(k)\) 的值。

当横坐标是连续整数:
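每个点的坐标为 \((x_i, y_i) = (i, f(i))\),即 \(x_i = i\)(\(i = 1, 2, \dots, n + 1\))。此时分母 \(\prod_{j \neq i} (i - j) = (-1)^{n + 1 - i}(i - 1)!(n + 1 - i)!\),对 \(k \notin \{1, \dots, n + 1\}\) 有:
\(f(k) = \sum_{i = 1}^{n + 1} y_i \frac{\prod_{j = 1}^{n + 1} (k - j)}{(k - i)(-1)^{n + 1 - i}(i - 1)!(n + 1 - i)!}\)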

例题:https://codeforc.es/problemset/problem/622/F
给定 \(n\)\(k\),求解 \(\sum_{i = 1}^{n} i^k\)\(10^9 + 7\) 取模的值。

#include<bits/stdc++.h>
using namespace std;
using LL = long long;
/// Residue class modulo the compile-time constant `mod`.
///
/// `x` always holds the canonical representative in [0, mod). Products are
/// formed in LL before reduction; because `mod` is an `int` (< 2^31), the
/// product of two representatives stays below 2^62 and cannot overflow.
/// `inv()` relies on Fermat's little theorem (a^(mod-2)), so it is only valid
/// for a prime `mod`; the "inverse" of 0 comes back as 0 rather than failing.
template<int mod>
struct ModZ{
	static_assert(mod > 0, "ModZ requires a positive modulus");
	LL x;   // canonical representative in [0, mod)
	constexpr ModZ() : x() {}
	constexpr ModZ(LL x) : x(norm(x)) {}
	/// Maps any LL, including negative values, into [0, mod).
	static constexpr LL norm(LL x) {return (x % mod + mod) % mod;}
	/// a^b by binary exponentiation; expects b >= 0. Static because it never
	/// reads *this (calling it through an object, e.g. `Z().power(..)`, still works).
	static constexpr ModZ power(ModZ a, LL b) {ModZ res = 1; for (; b; b /= 2, a *= a) if (b & 1) res *= a; return res;}
	/// Multiplicative inverse via Fermat; correct only for prime `mod`.
	constexpr ModZ inv() const {return power(*this, mod - 2);}
	constexpr ModZ &operator *= (ModZ rhs) & {x = norm(x * rhs.x); return *this;}
	constexpr ModZ &operator += (ModZ rhs) & {x = norm(x + rhs.x); return *this;}
	constexpr ModZ &operator -= (ModZ rhs) & {x = norm(x - rhs.x); return *this;}
	constexpr ModZ &operator /= (ModZ rhs) & {return *this *= rhs.inv();}
	friend constexpr ModZ operator * (ModZ lhs, ModZ rhs) {ModZ res = lhs; res *= rhs; return res;}
	friend constexpr ModZ operator + (ModZ lhs, ModZ rhs) {ModZ res = lhs; res += rhs; return res;}
	friend constexpr ModZ operator - (ModZ lhs, ModZ rhs) {ModZ res = lhs; res -= rhs; return res;}
	friend constexpr ModZ operator / (ModZ lhs, ModZ rhs) {ModZ res = lhs; res /= rhs; return res;}
	// Stream I/O cannot be constant-evaluated, so these are not constexpr.
	// `v` starts at 0 so a read from an already-failed stream yields 0 instead
	// of an indeterminate value.
	friend std::istream &operator >> (std::istream &is, ModZ &a) {LL v = 0; is >> v; a = ModZ(v); return is;}
	friend std::ostream &operator << (std::ostream &os, const ModZ &a) {return os << a.x;}
	friend constexpr bool operator == (ModZ lhs, ModZ rhs) {return lhs.x == rhs.x;}
	friend constexpr bool operator != (ModZ lhs, ModZ rhs) {return lhs.x != rhs.x;}
};
// Modulus of the target problem: 10^9 + 7. It is prime, which ModZ::inv()
// (Fermat exponentiation by mod - 2) requires.
constexpr int mod = 1000000007;
using Z = ModZ<mod>;

struct Lagrange{
	int n;
	vector<Z> x, y, fac, invfac;
	Lagrange(int n){
		this -> n = n;
		x.resize(n + 3);
		y.resize(n + 3);
		fac.resize(n + 3);
		invfac.resize(n + 3);
		init(n);
	}
	void init(int n){
		iota(x.begin(), x.end(), 0);
		for (int i = 1; i <= n + 2; i ++ ){
			Z t;
			y[i] = y[i - 1] + t.power(i, n);
		}
		fac[0] = 1;
		for (int i = 1; i <= n + 2; i ++ ){
			fac[i] = fac[i - 1] * i;
		}
		invfac[n + 2] = fac[n + 2].inv();
		for (int i = n + 1; i >= 0; i -- ){
			invfac[i] = invfac[i + 1] * (i + 1);
		}
	}
	Z solve(LL k){
		if (k <= n + 2){
			return y[k];
		}
		vector<Z> sub(n + 3);
		for (int i = 1; i <= n + 2; i ++ ){
			sub[i] = k - x[i];
		}
		vector<Z> mul(n + 3);
		mul[0] = 1;
		for (int i = 1; i <= n + 2; i ++ ){
			mul[i] = mul[i - 1] * sub[i];
		}
		Z ans = 0;
		for (int i = 1; i <= n + 2; i ++ ){
			ans = ans + y[i] * mul[n + 2] * sub[i].inv() * pow(-1, n + 2 - i) * invfac[i - 1] * invfac[n + 2 - i];
		}
		return ans;
	}
};
int main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	int k, n;
	cin >> k >> n;
	Lagrange LI(n);
	cout << LI.solve(k) << "\n";
	return 0;
}
posted on 2023-06-04 19:44  Hamine  阅读(106)  评论(0编辑  收藏  举报