【数学】拉格朗日插值
【数学】拉格朗日插值
题目描述
由小学知识可知 \(n\) 个点 \((x_i,y_i)\) 可以唯一地确定一个多项式 \(y = f(x)\)。
现在,给定这 \(n\) 个点,请你确定这个多项式,并求出 \(f(k) \bmod 998244353\) 的值。
\(1 \le n \leq 2\times 10^3\),\(1 \le x_i,y_i,k < 998244353\),\(x_i\) 两两不同。
算法描述
考虑构造一个多项式,满足\(\forall x_i,f(x_i) = y_i\)。考虑用一种手法,使若干项加起来,当\(x = x_i\)时,其他项全部是0,只有一项是\(y_i\)。即:
其中
对于\(g_i(x)\)只有在\(x_i\)这个值时才等于1这个性质,我们考虑用乘法来实现,即:
这样就能做到当\(x = x_i\)时它不为0,\(x = x_j(j \neq i)\)时为0(其他值不做要求)
那么怎样才能让\(x = x_i\)时等于1呢?观察到\(x = x_i\)时这个值等于
所以除掉一个相同的值就好了。所以
时间复杂度\(O(n^2)\)。
对于这道题\(n \leq 2 \times 10^3\)的情况(也是一般情况),直接做即可。
Code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MOD = 998244353;
inline ll ksm(ll base,ll pts)
{
ll ret = 1;
for(;pts > 0;pts >>= 1,base = base * base % MOD)
if(pts & 1)
ret = ret * base % MOD;
return ret;
}
ll n,k,x[2005],y[2005];
inline ll g(ll i,ll X)
{
ll ret = 1;
for(int j = 1;j <= n;j++)
{
if(i == j) continue;
ret = ret * (X - x[j] + MOD) % MOD;
}
for(int j = 1;j <= n;j++)
{
if(i == j) continue;
ret = ret * ksm((x[i] - x[j] + MOD) % MOD,MOD - 2) % MOD;
}
return ret;
}
int main()
{
cin>>n>>k;
for(int i = 1;i <= n;i++)
cin>>x[i]>>y[i];
ll ans = 0;
for(int i = 1;i <= n;i++)
ans += y[i] * g(i,k) % MOD,ans %= MOD;
cout<<ans;
return 0;
}
但是对于很多更加灵活的题目,时间复杂度要求更高,并且点值是自己选的,这个时候就可以利用点值自选的性质,降低复杂度至\(O(n)\)。
例题:The Sum of the k-th Powers - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
题目描述
题目描述:求 \((\sum_{i=1}^ni^k) \bmod (10^9+7)\)。
数据范围:\(1 \le n \le 10 ^ 9, 0 \le k \le 10 ^ 6\)。
算法描述
这个题又叫“自然数幂前缀和”,有广泛的应用,看到自然数的\(k\)次方求前缀和,猜想它是一个\(k + 1\)次多项式(一般几项加起来不行!),下面进行证明。
我们使用归纳法,假设
并假设\(f_{n,1}\)到\(f_{n,k - 2}\)已经都满足这个规律。
转化一下式子:
注意到\(\sum_{j = 1}^nj^i\)其实就是\(f_{n,i}\),所以
由归纳我们可以知道,\(f_{n,k - 2}\)不超过\(k - 1\)次,所以左边等式最高次是\(k\)次,右边也是。
由于\(f_{n,1}\)是一个二次多项式(高斯公式求和从1加到\(n\)),这样,我们就证得了\(f_{n,k - 1}\)是一个\(k\)次多项式。
所以\(\sum_{i = 1}^ni^k\)是一个\(k + 1\)次多项式。
知道了这点,由于\(n\)很大,\(k\)较小,我们求出\(k + 2\)个连续点值然后带进去拉格朗日插值求出答案即可。
然而我们怎样在\(O(k)\)的时间内完成插值呢?
考虑到点值连续,它们互相之间的差就是连续的,这里假设它们是\(1 \to k + 2\),代入式子,有:
而\(k + 2\)以内的阶乘及其逆元和\(\prod_{j = 1}^{k + 2}(x - j)\)都是可以通过预处理\(O(k)\)处理出来,在这里直接\(O(1)\)计算后面分式中的值即可。
时间复杂度\(O(k)\)
Code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6 + 5;
ll n,m,a[N],k,x[N],y[N],s[N],Pd = 1;
const int MOD = 1e9 + 7;
inline ll ksm(ll base,ll pts)
{
ll ret = 1;
for(;pts > 0;pts >>= 1,base = base * base % MOD)
if(pts & 1)
ret = ret * base % MOD;
return ret;
}
inline ll f(ll X)
{
ll ret = 0;
for(int i = 1;i <= k + 2;i++)
{
ll add = y[i];
add = add * Pd % MOD * ksm(X - x[i],MOD - 2) % MOD;
add = add * ((((k + 2 - i) % 2 == 0) ? 1 : -1) + MOD) % MOD * ksm(s[i - 1],MOD - 2) % MOD * ksm(s[k + 2 - i],MOD - 2) % MOD;
ret = (ret + add) % MOD;
}
return ret;
}
int main()
{
cin>>n>>k;
s[0] = 1;
for(int i = 1;i <= k + 2;i++) s[i] = s[i - 1] * i % MOD;
for(int i = 1;i <= k + 2;i++)
{
x[i] = i;
y[i] = ksm(i,k);
y[i] = (y[i] + y[i - 1]) % MOD;
Pd = Pd * ((n - x[i] + MOD) % MOD) % MOD;
}
if(n <= k + 2) cout<<y[n];
else cout<<f(n);
return 0;
}