为了能到远方,脚下的每一步都不能|

Aurora-JC

园龄:3年粉丝:3关注:4

2023-05-20 18:41阅读: 335评论: 0推荐: 0

【学习笔记】(8) 拉格朗日插值

拉格朗日插值

首先一个定理:

n 个点(横坐标不同)唯一确定一个最高 n1 次的多项式。

那么, n 个点的点值 (xi,yi) 可以唯一确定一个 n1 次多项式(为了叙述方便,本文中所有“ k 次多项式”“ k 次函数”的最高次项系数可以为 0)。

拉格朗日插值就是用来求这个多项式的。

例如,我们已知四个点值 (1,1)(0,2)(0.5,1.375)(1,1) ,要求过这四个点的三次函数 f

当然,你可以直接待定系数用高斯消元解方程,但那是 O(n3) 的,拉格朗日插值可以在 O(n2) 内求解。

约瑟夫·拉格朗日认为这个函数可以用四个三次函数线性组合出来。

首先构造一个三次函数 f1
,在 x=1 的取值为 1,但在其他三个点的取值为 0

类似地,构造 f2,3,4 依次在每个点取值为 1,在其他三个点取值为 0


画到一张图里就是这样:

那么这几个函数有啥用呢?

容易发现,f(x)=y1f1(x)+y2f2(x)+y3f3(x)+y4f4(x),把那四个点点值代入进去就可以知道。

现在问题就转化为了怎么求 f1,2,3,4

推导

我们可以构造函数:

f1(x)=(xx2)(xx3)(xx4)(x1x2)(x1x3)(x1x4)

f2(x)=(xx1)(xx3)(xx4)(x2x1)(x2x3)(x2x4)

f3(x)=(xx1)(xx2)(xx4)(x3x1)(x3x2)(x3x4)

f4(x)=(xx1)(xx2)(xx3)(x4x1)(x4x2)(x4x3)

把值回代,显然符合

那对于 n 个点值求 n1 次多项式的问题,我们先有点值 (xj,yj)1jn,设函数 fi1in,它们是 n1 次函数,且满足:

fi(xj)={1i=j0ij

则根据上面的构造函数,我们可以写成:

fi(x)=j=1,jin(xxj)(xixj)

最终得到:

f(x)=i=1nyifi(x)

P4781 【模板】拉格朗日插值

套上面公式直接算即可。

如果你在计算每个 xxjxixj 的时候都求一遍逆元,会导致复杂度变为 O(n2logn),多带一个逆元的 log。为了让复杂度瓶颈不在逆元上,我们通常分开算分子分母,在每个函数算完后再进行有理数取模,这样的复杂度为 O(n2)

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod = 998244353, N = 2e3 + 5;
int read(){
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
return x * f;
}
int n, k, ans;
int x[N], y[N];
int qsm(int a, int b){
int res = 1;
for(; b; b >>= 1, a = a * a % mod) if(b & 1) res = res * a % mod;
return res;
}
int inv(int x){return qsm(x, mod - 2);}
signed main(){
n = read(), k = read();
for(int i = 1; i <= n; ++i) x[i] = read(), y[i] = read();
for(int i = 1; i <= n; ++i){
int a = y[i], b = 1;
for(int j = 1; j <=n; ++j){
if(i != j){
a = a * (k - x[j]) % mod;
b = b * (x[i] - x[j]) % mod;
}
}
ans = (ans + a * inv(b) % mod + mod) % mod;
}
printf("%lld\n", ans);
return 0;
}

连续点值的插值

如果已知的点值是连续点的点值,我们可以做到 O(n) 的插值。

有时候发现了一个函数是 n 次多项式,就求 n+1 个点值进去插值。为了省事,这里我们令 xi=i1in+1。注意这里 n 与上面意义不一样,是次数而不是点数。

我们有拉插公式:

f(x)=i=1n+1yij=1,jin+1xxjxixj

代入 xi=i

f(x)=i=1n+1yij=1,jin+1xjij

考虑怎么快速求 j=1,jin+1xjij

上述式子的分子是:

j=1n+1(xj)xi

分母的话把 ij 累乘拆成两段阶乘,就是:

(1)n+1i(i1)!(n+1i)!

于是连续点值的插值公式:

f(x)=i=1n+1yij=1n+1(xj)(xi)(1)n+1i(i1)!(n+1i)!

预处理前后缀积、阶乘阶乘逆然后代这个式子的复杂度为 O(n)

CF622F The Sum of the k-th Powers

可以发现答案是 k+1 次多项式,因此代 k+2 个点进去拉插。

证明 : https://www.luogu.com.cn/blog/formkiller/cf622f-the-sum-of-the-k-th-powers-ti-xie#

本题的 ik 可以线性筛,因此复杂度可以做到 O(n)

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod = 1e9 + 7, N = 1e6 + 5;
int read(){
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
return x * f;
}
int n, k, cnt, ans;
int v[N], inv[N], fac[N], infac[N], prime[N];
int pre[N], suf[N], f[N];
int qsm(int a, int b){
int res = 1;
for(; b; b >>= 1, a = a * a % mod) if(b & 1) res = res * a % mod;
return res;
}
void solve(int nn){
f[1] = 1;
for(int i = 2; i <= nn; ++i){
if(!v[i]) v[i] = i, prime[++cnt] = i, f[i] = qsm(i, k);
for(int j = 1; j <= cnt; ++j){
if(prime[j] > v[i] || prime[j] > nn / i) break;
v[i * prime[j]] = prime[j], f[i * prime[j]] = f[i] * f[prime[j]] % mod;
}
}
for(int i = 2; i <= nn; ++i) (f[i] += f[i - 1]) %= mod;
return ;
}
signed main(){
n = read(), k = read();
solve(k + 2);
if(n <= k + 2) return printf("%lld\n", f[n]), 0;
pre[0] = suf[k + 3] = 1;
for(int i = 1; i <= k + 2; ++i) pre[i] = pre[i - 1] * (n - i) % mod;
for(int i = k + 2; i; --i) suf[i] = suf[i + 1] * (n - i) % mod;
infac[0] = infac[1] = inv[0] = fac[0] = inv[1] = fac[1] = 1;
for(int i = 2; i <= k + 2; ++i){
fac[i] = fac[i - 1] * i % mod;
inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
for(int i = 2; i <= k + 2; ++i) infac[i] = infac[i - 1] * inv[i] % mod;
for(int i = 1; i <= k + 2; ++i){
int p = pre[i - 1] * suf[i + 1] % mod;
int q = infac[i - 1] * infac[k + 2 - i] % mod;
int mul = ((k + 2 - i) & 1) ? -1 : 1;
ans = (ans + (q * mul + mod) % mod * p % mod * f[i] % mod) % mod;
}
printf("%lld\n", ans);
return 0;
}

P4593 [TJOI2018]教科书般的亵渎

首先,认真读题不难发现若血量的区间 [1,mi] 连续,则只需要一张亵渎就可以杀死区间 [1,mi] 内所有怪物,所以 k=m+1

考虑到这点,我们就可以轻松的写出式子(保证 ai 升序):

定义 a0=0,有

ans=i=1m+1(j=1nai1jm+1j=im(ajai1)m+1)

然后就跟上题一样了,发现 j=1nai1jm+1 是一个 m+2 次的多项式。

由于取值是连续的,就可以优化到 O(m),对于 j=im(ajai1)m+1 直接暴力求解就行了。

时间复杂度 O(m2logm)

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod = 1e9 + 7, N = 5e1 + 7;
int read(){
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
return x * f;
}
int T, n, m, k, cnt, ans;
int v[N], inv[N], fac[N], infac[N], prime[N];
int pre[N], suf[N], f[N], a[N];
int qsm(int a, int b){
int res = 1;
for(; b; b >>= 1, a = a * a % mod) if(b & 1) res = res * a % mod;
return res;
}
void solve(int nn){
memset(v, 0, sizeof(v));
f[1] = 1, cnt = 0;
for(int i = 2; i <= nn; ++i){
if(!v[i]) v[i] = i, prime[++cnt] = i, f[i] = qsm(i, k);
for(int j = 1; j <= cnt; ++j){
if(prime[j] > v[i] || prime[j] > nn / i) break;
v[i * prime[j]] = prime[j], f[i * prime[j]] = f[i] * f[prime[j]] % mod;
}
}
for(int i = 2; i <= nn; ++i) (f[i] += f[i - 1]) %= mod;
return ;
}
int lagrange(int x){
if(x <= m + 3) return f[x];
int res = 0; suf[m + 4] = pre[0] = 1;
for(int i = 1; i <= m + 3; ++i) pre[i] = pre[i - 1] * (x - i) % mod;
for(int i = m + 3; i; --i) suf[i] = suf[i + 1] * (x - i) % mod;
for(int i = 1; i <= m + 3; ++i)
res = (res + (((m + 3 - i) & 1) ? -1ll : 1ll) * (pre[i - 1] * suf[i + 1] % mod * infac[i - 1] % mod * infac[m + 3 - i] % mod * f[i] % mod) + mod) % mod;
return res;
}
signed main(){
infac[0] = infac[1] = inv[0] = fac[0] = inv[1] = fac[1] = 1;
for(int i = 2; i <= 54; ++i){
fac[i] = fac[i - 1] * i % mod;
inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
for(int i = 2; i <= 54; ++i) infac[i] = infac[i - 1] * inv[i] % mod;
T = read();
while(T--){
ans = 0;
n = read(), m = read();
for(int i = 1; i <= m; ++i) a[i] = read();
sort(a + 1, a + 1 + m); k = m + 1, solve(m + 3);
for(int i = 1; i <= m + 1; ++i){
ans = (ans + lagrange(n - a[i - 1])) % mod;
for(int j = i; j <= m; ++j) ans = (ans - qsm(a[j] - a[i - 1], k) + mod) % mod;
}
printf("%lld\n", ans);
}
return 0;
}

重心拉格朗日插值

如果每次加入一个数重新求多项式的插值,每次都是 O(n2) ,不优。

重心拉格朗日插值法可以在加入一个数后 O(n) 求出新的多项式的插值

f(x)=i=1nyij=1,jinxxjxixj

f(x)=i=1nyij=1n(xxj)(xxi)j=1,jin(xixj)

g=j=1n(xxj),w(i)=j=1,jin(xixj),那么有:

f(x)=gi=1nyi(xxi)w(i)

那么对于每一个新增加的插值点 ,我们可以 O(n) 的更新所有的 w(i), 求原函数仍然是 O(n2)

参考资料: 拉格朗日插值学习笔记 拉格朗日插值学习笔记

本文作者:南风未起

本文链接:https://www.cnblogs.com/jiangchen4122/p/17417626.html

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @   Aurora-JC  阅读(335)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起