关于拉格朗日插值

拉格朗日插值法

众所周知,\(n + 1\)\(x\) 坐标不同的点可以确定唯一的最高为 \(n\) 次的多项式。在算法竞赛中,我们常常会碰到一类题目,题目中直接或间接的给出了 \(n+1\) 个点,让我们求由这些点构成的多项式在某一位置的取值。

一个最显然的思路就是直接高斯消元求出多项式的系数,但是这样做复杂度巨大 \((n^3)\) 且根据算法实现不同往往会存在精度问题

而拉格朗日插值法可以在 \(n^2\) 的复杂度内完美解决上述问题

假设该多项式为 \(f(x)\), 第 \(i\) 个点的坐标为 \((x_i, y_i)\),我们需要找到该多项式在 \(k\) 点的取值

根据拉格朗日插值法

\[f(k) = \sum_{i = 0}^{n} y_i \prod_{i \not = j} \frac{k - x[j]}{x[i] - x[j]} \]

inline int Lagrange(int t){
	int res = 0;
	for(int i = 1; i <= n; i++){
		int tmp = y[i];
		for(int j = 1; j <= n; j++){
			if(i != j) tmp = 1ll * tmp * (x[j] - t)%md * pwr(x[j] - x[i], md - 2)%md;
		}
		res = (res + tmp)%md;
	}
	return (res + md)%md;
}

\(x\) 取值连续时的做法

在绝大多数题目中我们需要用到的 \(x_i\) 的取值都是连续的,这样的话我们可以把上面的算法优化到 \(O(n)\) 复杂度

首先把 \(x_i\) 换成 \(i\),新的式子为:

\[f(k) = \sum_{i=0}^n y_i \prod_{i \not = j} \frac{k - j}{i - j} \]

考虑如何快速计算 \(\prod_{i \not = j} \frac{k - j}{i - j}\)

对于分子来说,我们维护出关于 \(k\) 的前缀积和后缀积,也就是

\[pre_i = \prod_{j = 0}^{i} k - j \]

\[suf_i = \prod_{j = i}^n k - j \]

对于分母来说,观察发现这其实就是阶乘的形式,我们用 \(fac[i]\) 来表示 \(i!\)

那么式子就变成了

\[f(k) = \sum_{i=0}^n y_i \frac{pre_{i-1} * suf_{i+1}}{fac[i - 1] * fac[N - i]} \]

注意:分母可能会出现符号问题,也就是说,当 \(N - i\) 为奇数时,分母应该取负号

inline int Lagrange(int *f,int len,int x){
	int res = 0; pre[0] = x%md; suf[len + 1] = 1;
	for(int i = 1; i <= len; i++) pre[i] = 1ll * pre[i - 1] * (x - i)%md;
	for(int i = len; i >= 0; i--) suf[i] = 1ll * suf[i + 1] * (x - i)%md;
	for(int i = 0; i <= len; i++){
		int tmp = 1ll * f[i] * ifac[i]%md * ifac[len - i]%md * ((len - i)&1? md - 1 : 1)%md;
		if(i) tmp = 1ll * tmp * pre[i - 1]%md;
		if(i != len) tmp = 1ll * tmp * suf[i + 1]%md;
		res = (res + tmp)%md;
	}
	return res;
}

拉格朗日插系数

\[f(z)=∑_{i=1}^n\limits y_i\frac{∏_{j≠i}(z−x_j)}{∏_{j≠i}(x_i−x_j)} \]

首先提出常数部分:

\[a_i=\frac{y_i}{∏_{j≠i}(x_i−x_j)} \]

可以 \(O(n^2)\) 搞出每一个 \(a_i\)

然后求一个多项式 \(g(z)=∏^n_{i=1}\limits(z−x_i)\)

可以发现

\[f(z)=∑_{i=1}^n\limits a_i\frac{g(z)}{z−x_i} \]

考虑如何快速搞出后面那个 \(\frac{g(z)}{z−x_i}\)

\(h(z)=\frac{g(z)}{z−c}\)

可以得到 \((z−c)h(z)=g(z)\)。两边提取系数得到

\[[z^{i−1}]h−c[z^i]h=[z^i]g\\ [z^i]h=\frac{[z^i]g−[z^{i−1}]h}{−c} \]

递推即可。

inline vector<int> Langrange(const vector<int> &x,const vector<int> &y){
	int n = x.size();
	vector<int> a(n),b(n+1),c(n+1),f(n);
	for(int i = 0; i < n; i++){
		int tmp = 1;
		for(int j = 0; j < n; j++){
			if(i != j) tmp = 1ll * tmp * (x[i] - x[j] + md)%md;
		}
		a[i] = 1ll * y[i] * pwr(tmp,md - 2)%md;
	}
	b[0] = 1;
	for(int i = 0; i < n; i++){
		for(int j = n; j; j--) b[j] = (1ll * (md - x[i]) * b[j] + b[j - 1])%md;
		b[0] = 1ll * b[0] * (md - x[i])%md;
	}
	for(int i = 0; i < n; i++){
		int iv = pwr(md - x[i], md - 2);
		if(!iv)for(int j=0;j<n;j++)c[i]=b[j+1];
		else {
			c[0]=1ll*iv*b[0]%md;
			for(int j = 1; j < n; j++) c[j] = 1ll * (b[j] - c[j - 1] + md) * iv%md;
		}
		for(int j = 0; j < n; j++) f[j] = (f[j] + 1ll * a[i] * c[j])%md;
	}
	return f;
}
inline int calc(vector<int> &f,int x){
	int res = 0;
	for(int i = f.size() - 1; ~i; i--) res = (1ll * res * x + f[i])%md;
	return res;
}

【模板】拉格朗日插值

点击查看代码
#include<bits/stdc++.h>
using namespace std;
int n,k;
long long x[2005],y[2005];
const long long md=998244353;
inline long long pwr(long long x,long long y){
	long long res=1;
	while(y){
		if(y&1)res=res*x%md;
		x=x*x%md;y>>=1;
	}return res;
}
inline long long Lagrange(int t){
	long long res=0;
	for(int i=1;i<=n;i++){
		long long tmp=y[i];
		for(int j=1;j<=n;j++){
			if(i!=j)tmp=tmp*(x[j]-t)%md*pwr(x[j]-x[i],md-2)%md;
		}res=(res+tmp)%md;
	}return (res+md)%md;
}
int main(){
	scanf("%d%d",&n,&k);
	for(int i=1;i<=n;i++)scanf("%lld%lld",&x[i],&y[i]);
	printf("%lld",Lagrange(k));

	return 0;
}



拉格朗日插值2

点击查看代码
#include<bits/stdc++.h>
using namespace std;
int n,m;
const int md=998244353,G=3,Gi=(md+1)/3;
int r[1<<20],lim;
inline int pwr(int x,int y){
	int res=1;
	while(y){
		if(y&1)res=1ll*res*x%md;
		x=1ll*x*x%md;y>>=1;
	}
	return res;
}
inline void NTT(int *dp,int W){
	for(int i=0;i<(1<<lim);i++)if(i<r[i])swap(dp[i],dp[r[i]]);
	for(int i=0;i<lim;i++){
		int w=pwr(W,(md-1)/(1<<(i+1)));
		for(int j=0;j<(1<<lim);j+=(1<<(i+1))){
			int Pw=1;
			for(int t=0;t<(1<<i);t++){
				int x=dp[j+t],y=1ll*dp[j+(1<<i)+t]*Pw%md;
				dp[j+t]=(x+y)%md;dp[j+(1<<i)+t]=(x-y+md)%md;Pw=1ll*Pw*w%md;
			}
		}
	}
}
int fac[200005],inv[200005],ifac[200005];
inline void init(){
	while((1<<lim)<=2*n)lim++;
	for(int i=0;i<(1<<lim);i++)r[i]=(r[i>>1]>>1)+((i&1)<<(lim-1));
	fac[0]=fac[1]=inv[0]=inv[1]=ifac[0]=ifac[1]=1;
	for(int i=2;i<=n;i++)fac[i]=1ll*fac[i-1]*i%md;
	for(int i=2;i<=n;i++)inv[i]=1ll*(md-md/i)*inv[md%i]%md;
	for(int i=2;i<=n;i++)ifac[i]=1ll*inv[i]*ifac[i-1]%md;
}
int f[1<<20],g[1<<20],ltmp[400005],rtmp[400005];
int main(){
	scanf("%d%d",&n,&m);init();
	for(int i=0;i<=n;i++)scanf("%d",&f[i]);
	for(int i=0;i<=n;i++)f[i]=1ll*f[i]*ifac[i]%md*ifac[n-i]%md;
	for(int i=0;i<=n;i++)if((n-i)&1)f[i]=(md-f[i])%md;
	for(int i=0;i<=2*n;i++)g[i]=pwr(m-n+i,md-2);
	NTT(f,G);NTT(g,G);int INV=pwr((1<<lim),md-2);
	for(int i=0;i<(1<<lim);i++)f[i]=1ll*f[i]*g[i]%md;
	NTT(f,Gi);
	for(int i=0;i<(1<<lim);i++)f[i]=1ll*f[i]*INV%md;
	ltmp[0]=1;for(int i=1;i<=n;i++)ltmp[i]=1ll*ltmp[i-1]*(m-i)%md;
	rtmp[0]=m;for(int i=1;i<=n;i++)rtmp[i]=1ll*rtmp[i-1]*(m+i)%md;
	for(int i=0;i<=n;i++)f[i+n]=1ll*f[i+n]*ltmp[n-i]%md*rtmp[i]%md;
	for(int i=0;i<=n;i++)printf("%d ",f[i+n]);puts("");

	return 0;
}





[集训队互测 2012] calc

点击查看代码
#include<bits/stdc++.h>
using namespace std;
int k,n,md;
int fac[1005],inv[1005],ifac[1005];
inline void init(){
	fac[0]=fac[1]=inv[0]=inv[1]=ifac[0]=ifac[1]=1;
	for(int i=2;i<=2*n+1;i++)fac[i]=1ll*fac[i-1]*i%md;
	for(int i=2;i<=2*n+1;i++)inv[i]=1ll*(md-md/i)*inv[md%i]%md;
	for(int i=2;i<=2*n+1;i++)ifac[i]=1ll*ifac[i-1]*inv[i]%md;
}
int pre[1005],suf[1005];
inline int lag(int *f,int len,int x){
	int res=0;pre[0]=x%md;suf[len+1]=1;
	for(int i=1;i<=len;i++)pre[i]=1ll*pre[i-1]*(x-i)%md;
	for(int i=len;i>=0;i--)suf[i]=1ll*suf[i+1]*(x-i)%md;
	for(int i=0;i<=len;i++){
		int tmp=1ll*f[i]*ifac[i]%md*ifac[len-i]%md*((len-i)&1?md-1:1)%md;
		if(i)tmp=1ll*tmp*pre[i-1]%md;
		if(i!=len)tmp=1ll*tmp*suf[i+1]%md;
		res=(res+tmp)%md;
	}
	return res;
}
int dp[505][1005];
int main(){
	scanf("%d%d%d",&k,&n,&md);
	init();
	dp[0][0]=1;
	for(int i=1;i<=n;i++){
		int sum=dp[i-1][0];
		for(int j=1;j<=2*n+1;j++){
			dp[i][j]=1ll*j*sum%md;
			sum=(sum+dp[i-1][j])%md;
		}
	}
	for(int i=1;i<=2*n+1;i++)dp[n][i]=(dp[n][i]+dp[n][i-1])%md;
	dp[n][0]=lag(dp[n],2*n+1,k);
	printf("%lld",(1ll*dp[n][0]*fac[n]%md+md)%md);


	return 0;
}





[TJOI2018]教科书般的亵渎

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int md=1e9+7;
int T;
long long n;
int m;
int a[55],k,f[55],pre[55],suf[55],fac[55];
map<int,int> ma;
inline int pwr(int x,int y){
    int res=1;
    while(y){
        if(y&1)res=1ll*res*x%md;
        x=1ll*x*x%md;y>>=1;
    }
    return res;
}
inline int lag(long long x){
    if(x<=k+2)return f[x];
    int res=0;pre[0]=1;suf[k+3]=1;x%=md;
    for(int i=1;i<=k+2;i++)pre[i]=1ll*pre[i-1]*(x-i)%md;
    for(int i=k+2;i>=1;i--)suf[i]=1ll*suf[i+1]*(x-i)%md;
    for(int i=1;i<=k+2;i++){
        int x=1ll*pre[i-1]*suf[i+1]%md;
        int fu=((k+2-i)&1)?-1:1;
        int y=1ll*fac[i-1]*fac[k+2-i]%md*fu%md;
        res=(res+1ll*f[i]*x%md*pwr(y,md-2)%md)%md;
    }
    return (res+md)%md;
}
inline void init(){
    fac[0]=1;
    for(int i=1;i<=52;i++)fac[i]=1ll*fac[i-1]*i%md;
}
int main(){
    scanf("%d",&T);init();
	while(T--){
        scanf("%lld%d",&n,&m);k=m+1;
        ma.clear();
        for(int i=1;i<=m;i++)scanf("%d",&a[i]),ma[a[i]]=1;
        sort(a+1,a+m+1);
        while(ma[n])n--,k--,m--;
        for(int i=1;i<=k+2;i++)f[i]=(f[i-1]+pwr(i,k))%md;
        int res=lag(n);
        for(int i=1;i<=m;i++)res=(res-pwr(a[i],k))%md;
        for(int i=1;i<=m;i++)res=(res+lag(n-a[i]))%md;
        for(int i=1;i<=m;i++){
            for(int j=i-1;j>=1;j--){
				res=(res-pwr(a[i]-a[j],k))%md;
			}
		}
        printf("%d\n",(res+md)%md);
    }
    return 0;
}



[NOI2019] 机器人

点击查看代码
#include<bits/stdc++.h>
using namespace std;
int n;
const int md=1e9+7;
namespace Lagrange{
	int pre1[305],suf1[305],pre2[305],suf2[305],inv[305];
	inline void init(){
		inv[0]=inv[1]=1;
		for(int i=2;i<=n;i++)inv[i]=1ll*(md-md/i)*inv[md%i]%md;
	}
	inline int lag(int *y,int k,int t){
		if(t<=k)return y[t];
		pre1[0]=suf1[0]=suf2[k+1]=1;pre2[0]=t%md;
		for(int i=1;i<=k;i++)pre1[i]=1ll*pre1[i-1]*inv[i]%md;
		for(int i=1;i<=k;i++)suf1[i]=1ll*suf1[i-1]*-inv[i]%md;
		for(int i=1;i<=k;i++)pre2[i]=1ll*pre2[i-1]*(t-i)%md;
		for(int i=k;i>=1;i--)suf2[i]=1ll*suf2[i+1]*(t-i)%md;
		int res=0;
		for(int i=0;i<=k;i++)res=(res+1ll*y[i]*pre1[i]%md*suf1[k-i]%md*(i?pre2[i-1]:1)%md*suf2[i+1])%md;
		return (res+md)%md;
	}
}
using Lagrange::lag;
int A[305],B[305],lim;
int vis[305][305],cnt,L[3005],R[3005],dp[3005][10005];
void build(int l,int r){
	if(l>r||vis[l][r])return ;
	vis[l][r]=++cnt;L[cnt]=l;R[cnt]=r;
	if(l==r)return ;
	for(int i=l;i<=r;i++){
		if(abs((r-i)-(i-l))>2)continue;
		build(l,i-1);build(i+1,r);
	}
}
bool used[305][305];
void solve(int l,int r,int len,int v){
	if(l>r||used[l][r])return ;
	int id=vis[l][r];used[l][r]=1;
	for(int i=1;i<=len;i++)dp[id][i]=0;
	for(int i=l;i<=r;i++){
		if(abs((r-i)-(i-l))>2||A[i]>v||B[i]<=v)continue;
		solve(l,i-1,len,v);solve(i+1,r,len,v);
		for(int j=1;j<=len;j++)dp[id][j]=(dp[id][j]+1ll*dp[vis[l][i-1]][j]*dp[vis[i+1][r]][j-1])%md;
	}
	for(int i=1;i<=len;i++)dp[id][i]=(dp[id][i]+dp[id][i-1])%md;
}
inline void Getval(int l,int r){
	for(int i=1;i<=cnt;i++)dp[i][0]=lag(dp[i],R[i]-L[i]+1,r-l+1);
	for(int l=1;l<=n;l++)for(int r=l;r<=n;r++)used[l][r]=0;
}
vector<int> hsh;
int main(){
	scanf("%d",&n);
	Lagrange::init();
	for(int i=1;i<=n;i++)scanf("%d%d",&A[i],&B[i]),++B[i];
	for(int i=1;i<=n;i++)hsh.push_back(A[i]);
	for(int i=1;i<=n;i++)hsh.push_back(B[i]);
	sort(hsh.begin(),hsh.end());
	hsh.erase(unique(hsh.begin(),hsh.end()),hsh.end());
	for(int i=1;i<=n;i++)A[i]=upper_bound(hsh.begin(),hsh.end(),A[i])-hsh.begin();
	for(int i=1;i<=n;i++)B[i]=upper_bound(hsh.begin(),hsh.end(),B[i])-hsh.begin();
	for(int i=0;i<=n;i++)dp[0][i]=1;
	build(1,n);
	for(int i=0;i+1<hsh.size();i++){
		for(int l=1;l<=n;l++){
			for(int r=l;r<=n;r++)if(vis[l][r])solve(l,r,min(hsh[i+1]-hsh[i],n+1),i+1);
		}
		Getval(hsh[i],hsh[i+1]-1);
	}

	printf("%d\n",dp[vis[1][n]][0]);

	return 0;
}


[APIO2016]划艇

点击查看代码
#include<bits/stdc++.h>
using namespace std;
int n;
int a[505],b[505];
vector<int> hsh;
const int md=1e9+7;
int fac[505],inv[505],ifac[505];
inline void init(){
	fac[0]=fac[1]=inv[0]=inv[1]=ifac[0]=ifac[1]=1;
	for(int i=2;i<=n;i++)fac[i]=1ll*fac[i-1]*i%md;
	for(int i=2;i<=n;i++)inv[i]=1ll*(md-md/i)*inv[md%i]%md;
	for(int i=2;i<=n;i++)ifac[i]=1ll*ifac[i-1]*inv[i]%md;
}
int tmp[505],pre[505],suf[505];
inline int lag(int *f,int len,int x){
	if(x<=len)return f[x];
	int res=0;pre[0]=x%md;suf[len+1]=1;
	for(int i=1;i<=len;i++)pre[i]=1ll*pre[i-1]*(x-i)%md;
	for(int i=len;i>=0;i--)suf[i]=1ll*suf[i+1]*(x-i)%md;
	for(int i=0;i<=len;i++){
		int tmp=1ll*f[i]*ifac[i]%md*ifac[len-i]%md*((len-i)&1?md-1:1)%md;
		if(i)tmp=1ll*tmp*pre[i-1]%md;
		if(i!=len)tmp=1ll*tmp*suf[i+1]%md;
		res=(res+tmp)%md;
	}
	return res;
}
int dp[505][505];
int main(){
	scanf("%d",&n);init();
	for(int i=1;i<=n;i++)scanf("%d%d",&a[i],&b[i]),b[i]++;
	for(int i=1;i<=n;i++)hsh.push_back(a[i]);
	for(int i=1;i<=n;i++)hsh.push_back(b[i]);
	sort(hsh.begin(),hsh.end());
	hsh.erase(unique(hsh.begin(),hsh.end()),hsh.end());
	for(int i=1;i<=n;i++)a[i]=upper_bound(hsh.begin(),hsh.end(),a[i])-hsh.begin();
	for(int i=1;i<=n;i++)b[i]=upper_bound(hsh.begin(),hsh.end(),b[i])-hsh.begin();
	for(int i=0;i<=n;i++)dp[i][0]=1;
	for(int t=1;t<hsh.size();t++){
		for(int i=1;i<=n;i++){
			if(a[i]<=t&&b[i]>t){
				int sum=dp[i-1][0];
				for(int j=1;j<=n;j++){
					dp[i][j]=(sum+dp[i-1][j])%md;
					sum=(sum+dp[i-1][j])%md;
				}
			}
			else for(int j=1;j<=n;j++)dp[i][j]=dp[i-1][j];
		}
		for(int i=1;i<=n;i++){
			for(int j=1;j<=i;j++)dp[i][j]=(dp[i][j]+dp[i][j-1])%md;
			dp[i][0]=lag(dp[i],i,hsh[t]-hsh[t-1]);
		}
	}
	printf("%d",(dp[n][0]+md-1)%md);


	return 0;
}





posted @ 2022-06-21 22:18  一粒夸克  阅读(75)  评论(0编辑  收藏  举报