【数学】拉格朗日插值

【数学】拉格朗日插值

题目描述

由小学知识可知 \(n\) 个点 \((x_i,y_i)\) 可以唯一地确定一个多项式 \(y = f(x)\)

现在,给定这 \(n\) 个点,请你确定这个多项式,并求出 \(f(k) \bmod 998244353\) 的值。

\(1 \le n \leq 2\times 10^3\)\(1 \le x_i,y_i,k < 998244353\)\(x_i\) 两两不同。

算法描述

考虑构造一个多项式,满足\(\forall x_i,f(x_i) = y_i\)。考虑用一种手法,使若干项加起来,当\(x = x_i\)时,其他项全部是0,只有一项是\(y_i\)。即:

\[f(x) = \sum_{i = 1}^{n}y_ig_i(x_i) \]

其中

\[g_i(x) = \begin{cases} 0\ \ \ x = x_j(j \neq i) \\ 1\ \ \ x = x_i \end{cases} \]

对于\(g_i(x)\)只有在\(x_i\)这个值时才等于1这个性质,我们考虑用乘法来实现,即:

\[\prod_{i \neq j}(x - x_j) \]

这样就能做到当\(x = x_i\)时它不为0,\(x = x_j(j \neq i)\)时为0(其他值不做要求)

那么怎样才能让\(x = x_i\)时等于1呢?观察到\(x = x_i\)时这个值等于

\[\prod_{i \neq j}(x_i - x_j) \]

所以除掉一个相同的值就好了。所以

\[g_i(x) = \prod_{i \neq j}\frac {(x - x_j)}{(x_i - x_j)} \]

\[f(x) = \sum_{i = 1}^ny_i\prod_{i \neq j}\frac {(x - x_j)}{(x_i - x_j)} \]

时间复杂度\(O(n^2)\)

对于这道题\(n \leq 2 \times 10^3\)的情况(也是一般情况),直接做即可。

Code

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MOD = 998244353;
inline ll ksm(ll base,ll pts)
{
    ll ret = 1;
    for(;pts > 0;pts >>= 1,base = base * base % MOD)
        if(pts & 1)
            ret = ret * base % MOD;
    return ret; 
}
ll n,k,x[2005],y[2005];
inline ll g(ll i,ll X)
{
    ll ret = 1;
    for(int j = 1;j <= n;j++)
    {
        if(i == j) continue;
        ret = ret * (X - x[j] + MOD) % MOD;
    }
    for(int j = 1;j <= n;j++)
    {
        if(i == j) continue;
        ret = ret * ksm((x[i] - x[j] + MOD) % MOD,MOD - 2) % MOD;
    }
    return ret;
}
int main()
{
    cin>>n>>k;
    for(int i = 1;i <= n;i++)
        cin>>x[i]>>y[i];
    ll ans = 0;
    for(int i = 1;i <= n;i++)
        ans += y[i] * g(i,k) % MOD,ans %= MOD;
    cout<<ans;
    return 0;
 } 

但是对于很多更加灵活的题目,时间复杂度要求更高,并且点值是自己选的,这个时候就可以利用点值自选的性质,降低复杂度至\(O(n)\)

例题:The Sum of the k-th Powers - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

题目描述

题目描述:求 \((\sum_{i=1}^ni^k) \bmod (10^9+7)\)

数据范围:\(1 \le n \le 10 ^ 9, 0 \le k \le 10 ^ 6\)

算法描述

这个题又叫“自然数幂前缀和”,有广泛的应用,看到自然数的\(k\)次方求前缀和,猜想它是一个\(k + 1\)次多项式(一般几项加起来不行!),下面进行证明。

我们使用归纳法,假设

\[f_{n,k} = \sum_{i = 1}^{n}i^k \]

并假设\(f_{n,1}\)\(f_{n,k - 2}\)已经都满足这个规律。

转化一下式子:

\[(n + 1)^k - n^k\\ =\sum_{i = 0}^{k}\binom{k}{i}n^i \ \ \ -\ \ \ n^k\\ = \sum_{i = 0}^{k - 1}\binom kin^i\\ \]

\[(n + 1)^k - 1^k = (n + 1)^k - n^k + n^k - (n - 1)^k + ... + 2^k - 1^k \]

\[= \sum_{j = 1}^n\sum_{i = 0}^{k - 1}\binom kij^i\\ \]

\[= \sum_{i = 0}^{k - 1}\binom ki\sum_{j = 1}^nj^i \]

注意到\(\sum_{j = 1}^nj^i\)其实就是\(f_{n,i}\),所以

\[(n + 1)^k - 1 = \sum_{i = 0}^{k - 1}\binom kif_{n,i}\\ \]

\[= \sum_{i = 0}^{k - 2}\binom kif_{n,i}\ \ \ + \ \ \ kf_{n,k - 1}\\ \]

\[(n + 1)^k - 1 - \sum_{i = 0}^{k - 2}\binom kif_{n,i} = f_{n,k - 1} \]

由归纳我们可以知道,\(f_{n,k - 2}\)不超过\(k - 1\)次,所以左边等式最高次是\(k\)次,右边也是。

由于\(f_{n,1}\)是一个二次多项式(高斯公式求和从1加到\(n\)),这样,我们就证得了\(f_{n,k - 1}\)是一个\(k\)次多项式。

所以\(\sum_{i = 1}^ni^k\)是一个\(k + 1\)次多项式。

知道了这点,由于\(n\)很大,\(k\)较小,我们求出\(k + 2\)个连续点值然后带进去拉格朗日插值求出答案即可。

然而我们怎样在\(O(k)\)的时间内完成插值呢?

考虑到点值连续,它们互相之间的差就是连续的,这里假设它们是\(1 \to k + 2\),代入式子,有:

\[f(x) = \sum_{i = 1}^{k + 2}y_i\prod_{i \neq j}\frac {(x - x_j)}{(x_i - x_j)}\\ \]

\[= \sum_{i = 1}^{k + 2}y_i\prod_{i \neq j}\frac {(x - j)}{(i - j)}\\ \]

\[= \sum_{i = 1}^{k + 2}y_i\frac {\prod_{i \neq j}(x - j)}{(i - 1)!(k + 2 - i)!(-1)^{k + 2 - i}}\\ \]

\[= \sum_{i = 1}^{k + 2}y_i\frac {\prod_{j = 1}^{k + 2}(x - j)}{(i - 1)!(k + 2 - i)!(-1)^{k + 2 - i}(x - i)}\\ \]

\(k + 2\)以内的阶乘及其逆元和\(\prod_{j = 1}^{k + 2}(x - j)\)都是可以通过预处理\(O(k)\)处理出来,在这里直接\(O(1)\)计算后面分式中的值即可。

时间复杂度\(O(k)\)

Code

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6 + 5;
ll n,m,a[N],k,x[N],y[N],s[N],Pd = 1;
const int MOD = 1e9 + 7;
inline ll ksm(ll base,ll pts)
{
	ll ret = 1;
	for(;pts > 0;pts >>= 1,base = base * base % MOD)
		if(pts & 1)
			ret = ret * base % MOD;
	return ret;
}
inline ll f(ll X)
{
	ll ret = 0;
	for(int i = 1;i <= k + 2;i++)
	{
		ll add = y[i];
		add = add * Pd % MOD * ksm(X - x[i],MOD - 2) % MOD;
		add = add * ((((k + 2 - i) % 2 == 0) ? 1 : -1) + MOD) % MOD * ksm(s[i - 1],MOD - 2) % MOD * ksm(s[k + 2 - i],MOD - 2) % MOD;
		ret = (ret + add) % MOD;
	}
	return ret;
}
int main()
{
	cin>>n>>k;
	s[0] = 1;
	for(int i = 1;i <= k + 2;i++) s[i] = s[i - 1] * i % MOD;
	for(int i = 1;i <= k + 2;i++)
	{
		x[i] = i;
		y[i] = ksm(i,k);
		y[i] = (y[i] + y[i - 1]) % MOD;
		Pd = Pd * ((n - x[i] + MOD) % MOD) % MOD;
	}
	if(n <= k + 2) cout<<y[n];
	else cout<<f(n);
	return 0;
}
posted @ 2023-06-07 08:05  The_Last_Candy  阅读(15)  评论(0编辑  收藏  举报