拉格朗日插值法 (Lagrange interpolation approach) 学习笔记

Lagrange interpolation approach 是要解决一种如下的问题:

给定 \(n\) 个坐标,\((x_1, y_1), (x_2, y_2), \dots, (x_n, y_n)\),确定一个多项式 \(f(x) = a_0 + a_1x + a_2x^2 + \dots + a_dx^d\) 满足:

\[f(x_1) = y_1 \]

\[f(x_2) = y_2 \]

\[\dots \]

\[f(x_n) = y_n \]

一、高斯消元

回想一下学 FFT 的时候,使用点值表示法,用 \(k + 1\) 个点表示一个 \(k\) 次的多项式。

那么我们可以设:

\[f(x) = a_0 + a_1x + a_2x^2 + \dots + a_{n - 1}x^{n - 1} \]

然后我们可以列出方程组:

\[\begin{cases}a_0 + a_1x_1 + a_2x_1^2 + \dots + a_{n - 1}x_1^{n - 1} = y_1\\a_0 + a_1x_2 + a_2x_2^2 + \dots + a_{n - 1}x_2^{n - 1} = y_2\\\dots\\a_0 + a_1x_n + a_2x_n^2 + \dots + a_{n - 1}x_n^{n - 1} = y_n \end{cases} \]

然后就是一次简单的高斯消元,时间复杂度 \(O(n ^ 3)\)

二、拉格朗日插值

如果对推导感兴趣可以看这个:https://oi-wiki.org/math/poly/lagrange/

直接放公式:

\[f(x) = \sum_{i = 1}^{n} y_i \prod_{j \neq i} \frac{x - x_j}{x_i - x_j} \]

三、洛谷 P4781 【模板】拉格朗日插值

https://www.luogu.com.cn/problem/P4781

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/hash_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
#ifdef LOCAL
#include "algo/debug.h"
#else
#define debug(...) 42
#endif
typedef long long ll;
typedef pair < int, int > PII;
typedef int itn;
mt19937 RND_MAKER (chrono :: steady_clock :: now ().time_since_epoch ().count ());
inline ll randomly (const ll l, const ll r) {return (RND_MAKER () ^ (1ull << 63)) % (r - l + 1) + l;}
#define int long long
const double pi = acos (-1);
//__gnu_pbds :: tree < Key, Mapped, Cmp_Fn = std :: less < Key >, Tag = rb_tree_tag, Node_Upadte = null_tree_node_update, Allocator = std :: allocator < char > > ;
//__gnu_pbds :: tree < PPS, __gnu_pbds :: null_type, less < PPS >, __gnu_pbds :: rb_tree_tag, __gnu_pbds :: tree_order_statistics_node_update > tr;
inline int read () {
	int x = 0, f = 0;
	char c = getchar ();
	for ( ; c < '0' || c > '9' ; c = getchar ()) f |= (c == '-');
	for ( ; c >= '0' && c <= '9' ; c = getchar ()) x = (x << 1) + (x << 3) + (c & 15);
	return !f ? x : -x;
}
const int mod = 998244353;
const int N = 2e3 + 5;
int x[N], y[N], n, k, ans;
inline int pow_mod (int a, int b, int p) {
	int res = 1;
	while (b) {
		if (b & 1) res = res * a % p;
		b >>= 1;
		a = a * a % p;
	}
	return res;
}
signed main () {
	n = read (), k = read ();
	for (int i = 1;i <= n; ++ i) {
		x[i] = read (), y[i] = read ();
	}
	for (int i = 1;i <= n; ++ i) {
		int tmp = y[i];
		for (int j = 1;j <= n; ++ j) {
			if (i != j) {
				int l = k - x[j];
				l = (l % mod + mod) % mod;
				tmp = tmp * l % mod;
				l = x[i] - x[j];
				l = (l % mod + mod) % mod;
				tmp = tmp * pow_mod (l, mod - 2, mod) % mod;
			}
		}
		ans = (ans + tmp) % mod;
	}
	printf ("%lld\n", (ans % mod + mod) % mod);
	return 0;
}

四、Codeforces 622F - The Sum of the k-th Powers

声明:博主数学很菜,所以并不会证明答案就是一个多项式 /kk /kk /kk。

我们把前 \(k + 2\) 个点都拎出来,比如第 \(i\) 个点就是 \(\displaystyle \left(i, \sum_{j = 1}^{i} j ^ k\right)\)

然后就做一遍拉格朗日插值,把 \(f(n)\) 输出即可。

但是这样直接做是 \(O(k^2)\) 的,\(k\) 是 1e6 级别的,怎么办?

夹带力度!爆推式子:

我们暂且将 \(w_u\) 设为 \(\displaystyle \sum_{i = 1}^{u} i^k\)

\[f(x) = \sum_{i = 1}^{n} w_i \prod_{j \neq i} \frac{x - j}{i - j} \]

这个 \(j \neq i\) 的限制条件十分的烦人,我们可以拆成两个区间:

\[f(x) = \sum_{i = 1}^{n} \left(w_i \times \prod_{j = 1}^{i - 1} \frac{x - j}{i - j} \times \prod_{j = i + 1}^{n}\frac{x - j}{i - j}\right) \]

\[f(x) = \sum_{i = 1}^{n} \left(w_i \times \frac{\prod\limits_{j = 1}^{i - 1} (x - j)}{(i - 1)!} \times \frac{\prod\limits_{j = i + 1}^{n} (x - j)}{(-1)^{n - i} \times (n - i)!}\right) \]

特别的,当 \(k \leq 0\) 时,\(k! = 1\)

分子可以直接前缀积算,然后分母可以预处理阶乘的逆元,就做完了。

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/hash_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
#ifdef LOCAL
#include "algo/debug.h"
#else
#define debug(...) 42
#endif
typedef long long ll;
typedef pair < int, int > PII;
typedef int itn;
mt19937 RND_MAKER (chrono :: steady_clock :: now ().time_since_epoch ().count ());
inline ll randomly (const ll l, const ll r) {return (RND_MAKER () ^ (1ull << 63)) % (r - l + 1) + l;}
#define int long long
const double pi = acos (-1);
//__gnu_pbds :: tree < Key, Mapped, Cmp_Fn = std :: less < Key >, Tag = rb_tree_tag, Node_Upadte = null_tree_node_update, Allocator = std :: allocator < char > > ;
//__gnu_pbds :: tree < PPS, __gnu_pbds :: null_type, less < PPS >, __gnu_pbds :: rb_tree_tag, __gnu_pbds :: tree_order_statistics_node_update > tr;
inline int read () {
	int x = 0, f = 0;
	char c = getchar ();
	for ( ; c < '0' || c > '9' ; c = getchar ()) f |= (c == '-');
	for ( ; c >= '0' && c <= '9' ; c = getchar ()) x = (x << 1) + (x << 3) + (c & 15);
	return !f ? x : -x;
}
const int mod = 1e9 + 7;
const int N = 1e6 + 5;
int w[N], n, k, ans, fac[N], ifac[N], pre[N], suf[N];
inline int pow_mod (int a, int b, int p) {
	int res = 1;
	while (b) {
		if (b & 1) res = res * a % p;
		b >>= 1;
		a = a * a % p;
	}
	return res;
}
signed main () {
	fac[0] = ifac[0] = 1;
	for (int i = 1;i < N; ++ i) {
		fac[i] = fac[i - 1] * i % mod;
		ifac[i] = pow_mod (fac[i], mod - 2, mod);
	}
	k = read (), n = read ();
	pre[0] = 1, suf[n + 3] = 1;
	for (int i = 1;i <= n + 2; ++ i) {
		int cur = k - i;
		cur = (cur % mod + mod) % mod;
		pre[i] = pre[i - 1] * cur % mod;
	}
	for (int i = n + 2;i >= 0; -- i) {
		int cur = k - i;
		cur = (cur % mod + mod) % mod;
		suf[i] = suf[i + 1] * cur % mod;
	}
	for (int i = 1;i <= n + 2; ++ i) w[i] = (w[i - 1] + pow_mod (i, n, mod)) % mod;
	for (int i = 1;i <= n + 2; ++ i) {
		int tmp = w[i];
		tmp = tmp * ifac[i - 1] % mod;
		tmp = tmp * ifac[n + 2 - i] % mod;
		if ((n + 2 - i) & 1) tmp = mod - tmp;
		tmp = (tmp % mod + mod) % mod;
		tmp = tmp * pre[i - 1] % mod;
		tmp = tmp * suf[i + 1] % mod;
		ans = (ans + tmp) % mod;
	}
	printf ("%lld\n", (ans % mod + mod) % mod);
	return 0;
}
posted @ 2023-05-04 21:47  CountingGroup  阅读(44)  评论(0编辑  收藏  举报