关于拉格朗日插值
拉格朗日插值法
众所周知,\(n + 1\) 个 \(x\) 坐标不同的点可以确定唯一的最高为 \(n\) 次的多项式。在算法竞赛中,我们常常会碰到一类题目,题目中直接或间接的给出了 \(n+1\) 个点,让我们求由这些点构成的多项式在某一位置的取值。
一个最显然的思路就是直接高斯消元求出多项式的系数,但是这样做复杂度巨大 \((n^3)\) 且根据算法实现不同往往会存在精度问题
而拉格朗日插值法可以在 \(n^2\) 的复杂度内完美解决上述问题
假设该多项式为 \(f(x)\), 第 \(i\) 个点的坐标为 \((x_i, y_i)\),我们需要找到该多项式在 \(k\) 点的取值
根据拉格朗日插值法
inline int Lagrange(int t){
int res = 0;
for(int i = 1; i <= n; i++){
int tmp = y[i];
for(int j = 1; j <= n; j++){
if(i != j) tmp = 1ll * tmp * (x[j] - t)%md * pwr(x[j] - x[i], md - 2)%md;
}
res = (res + tmp)%md;
}
return (res + md)%md;
}
在 \(x\) 取值连续时的做法
在绝大多数题目中我们需要用到的 \(x_i\) 的取值都是连续的,这样的话我们可以把上面的算法优化到 \(O(n)\) 复杂度
首先把 \(x_i\) 换成 \(i\),新的式子为:
考虑如何快速计算 \(\prod_{i \not = j} \frac{k - j}{i - j}\)
对于分子来说,我们维护出关于 \(k\) 的前缀积和后缀积,也就是
对于分母来说,观察发现这其实就是阶乘的形式,我们用 \(fac[i]\) 来表示 \(i!\)
那么式子就变成了
注意:分母可能会出现符号问题,也就是说,当 \(N - i\) 为奇数时,分母应该取负号
inline int Lagrange(int *f,int len,int x){
int res = 0; pre[0] = x%md; suf[len + 1] = 1;
for(int i = 1; i <= len; i++) pre[i] = 1ll * pre[i - 1] * (x - i)%md;
for(int i = len; i >= 0; i--) suf[i] = 1ll * suf[i + 1] * (x - i)%md;
for(int i = 0; i <= len; i++){
int tmp = 1ll * f[i] * ifac[i]%md * ifac[len - i]%md * ((len - i)&1? md - 1 : 1)%md;
if(i) tmp = 1ll * tmp * pre[i - 1]%md;
if(i != len) tmp = 1ll * tmp * suf[i + 1]%md;
res = (res + tmp)%md;
}
return res;
}
拉格朗日插系数
首先提出常数部分:
可以 \(O(n^2)\) 搞出每一个 \(a_i\)。
然后求一个多项式 \(g(z)=∏^n_{i=1}\limits(z−x_i)\)。
可以发现
考虑如何快速搞出后面那个 \(\frac{g(z)}{z−x_i}\)。
设 \(h(z)=\frac{g(z)}{z−c}\)。
可以得到 \((z−c)h(z)=g(z)\)。两边提取系数得到
递推即可。
inline vector<int> Langrange(const vector<int> &x,const vector<int> &y){
int n = x.size();
vector<int> a(n),b(n+1),c(n+1),f(n);
for(int i = 0; i < n; i++){
int tmp = 1;
for(int j = 0; j < n; j++){
if(i != j) tmp = 1ll * tmp * (x[i] - x[j] + md)%md;
}
a[i] = 1ll * y[i] * pwr(tmp,md - 2)%md;
}
b[0] = 1;
for(int i = 0; i < n; i++){
for(int j = n; j; j--) b[j] = (1ll * (md - x[i]) * b[j] + b[j - 1])%md;
b[0] = 1ll * b[0] * (md - x[i])%md;
}
for(int i = 0; i < n; i++){
int iv = pwr(md - x[i], md - 2);
if(!iv)for(int j=0;j<n;j++)c[i]=b[j+1];
else {
c[0]=1ll*iv*b[0]%md;
for(int j = 1; j < n; j++) c[j] = 1ll * (b[j] - c[j - 1] + md) * iv%md;
}
for(int j = 0; j < n; j++) f[j] = (f[j] + 1ll * a[i] * c[j])%md;
}
return f;
}
inline int calc(vector<int> &f,int x){
int res = 0;
for(int i = f.size() - 1; ~i; i--) res = (1ll * res * x + f[i])%md;
return res;
}
【模板】拉格朗日插值
点击查看代码
#include<bits/stdc++.h>
using namespace std;
int n,k;
long long x[2005],y[2005];
const long long md=998244353;
inline long long pwr(long long x,long long y){
long long res=1;
while(y){
if(y&1)res=res*x%md;
x=x*x%md;y>>=1;
}return res;
}
inline long long Lagrange(int t){
long long res=0;
for(int i=1;i<=n;i++){
long long tmp=y[i];
for(int j=1;j<=n;j++){
if(i!=j)tmp=tmp*(x[j]-t)%md*pwr(x[j]-x[i],md-2)%md;
}res=(res+tmp)%md;
}return (res+md)%md;
}
int main(){
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)scanf("%lld%lld",&x[i],&y[i]);
printf("%lld",Lagrange(k));
return 0;
}
拉格朗日插值2
点击查看代码
#include<bits/stdc++.h>
using namespace std;
int n,m;
const int md=998244353,G=3,Gi=(md+1)/3;
int r[1<<20],lim;
inline int pwr(int x,int y){
int res=1;
while(y){
if(y&1)res=1ll*res*x%md;
x=1ll*x*x%md;y>>=1;
}
return res;
}
inline void NTT(int *dp,int W){
for(int i=0;i<(1<<lim);i++)if(i<r[i])swap(dp[i],dp[r[i]]);
for(int i=0;i<lim;i++){
int w=pwr(W,(md-1)/(1<<(i+1)));
for(int j=0;j<(1<<lim);j+=(1<<(i+1))){
int Pw=1;
for(int t=0;t<(1<<i);t++){
int x=dp[j+t],y=1ll*dp[j+(1<<i)+t]*Pw%md;
dp[j+t]=(x+y)%md;dp[j+(1<<i)+t]=(x-y+md)%md;Pw=1ll*Pw*w%md;
}
}
}
}
int fac[200005],inv[200005],ifac[200005];
inline void init(){
while((1<<lim)<=2*n)lim++;
for(int i=0;i<(1<<lim);i++)r[i]=(r[i>>1]>>1)+((i&1)<<(lim-1));
fac[0]=fac[1]=inv[0]=inv[1]=ifac[0]=ifac[1]=1;
for(int i=2;i<=n;i++)fac[i]=1ll*fac[i-1]*i%md;
for(int i=2;i<=n;i++)inv[i]=1ll*(md-md/i)*inv[md%i]%md;
for(int i=2;i<=n;i++)ifac[i]=1ll*inv[i]*ifac[i-1]%md;
}
int f[1<<20],g[1<<20],ltmp[400005],rtmp[400005];
int main(){
scanf("%d%d",&n,&m);init();
for(int i=0;i<=n;i++)scanf("%d",&f[i]);
for(int i=0;i<=n;i++)f[i]=1ll*f[i]*ifac[i]%md*ifac[n-i]%md;
for(int i=0;i<=n;i++)if((n-i)&1)f[i]=(md-f[i])%md;
for(int i=0;i<=2*n;i++)g[i]=pwr(m-n+i,md-2);
NTT(f,G);NTT(g,G);int INV=pwr((1<<lim),md-2);
for(int i=0;i<(1<<lim);i++)f[i]=1ll*f[i]*g[i]%md;
NTT(f,Gi);
for(int i=0;i<(1<<lim);i++)f[i]=1ll*f[i]*INV%md;
ltmp[0]=1;for(int i=1;i<=n;i++)ltmp[i]=1ll*ltmp[i-1]*(m-i)%md;
rtmp[0]=m;for(int i=1;i<=n;i++)rtmp[i]=1ll*rtmp[i-1]*(m+i)%md;
for(int i=0;i<=n;i++)f[i+n]=1ll*f[i+n]*ltmp[n-i]%md*rtmp[i]%md;
for(int i=0;i<=n;i++)printf("%d ",f[i+n]);puts("");
return 0;
}
[集训队互测 2012] calc
点击查看代码
#include<bits/stdc++.h>
using namespace std;
int k,n,md;
int fac[1005],inv[1005],ifac[1005];
inline void init(){
fac[0]=fac[1]=inv[0]=inv[1]=ifac[0]=ifac[1]=1;
for(int i=2;i<=2*n+1;i++)fac[i]=1ll*fac[i-1]*i%md;
for(int i=2;i<=2*n+1;i++)inv[i]=1ll*(md-md/i)*inv[md%i]%md;
for(int i=2;i<=2*n+1;i++)ifac[i]=1ll*ifac[i-1]*inv[i]%md;
}
int pre[1005],suf[1005];
inline int lag(int *f,int len,int x){
int res=0;pre[0]=x%md;suf[len+1]=1;
for(int i=1;i<=len;i++)pre[i]=1ll*pre[i-1]*(x-i)%md;
for(int i=len;i>=0;i--)suf[i]=1ll*suf[i+1]*(x-i)%md;
for(int i=0;i<=len;i++){
int tmp=1ll*f[i]*ifac[i]%md*ifac[len-i]%md*((len-i)&1?md-1:1)%md;
if(i)tmp=1ll*tmp*pre[i-1]%md;
if(i!=len)tmp=1ll*tmp*suf[i+1]%md;
res=(res+tmp)%md;
}
return res;
}
int dp[505][1005];
int main(){
scanf("%d%d%d",&k,&n,&md);
init();
dp[0][0]=1;
for(int i=1;i<=n;i++){
int sum=dp[i-1][0];
for(int j=1;j<=2*n+1;j++){
dp[i][j]=1ll*j*sum%md;
sum=(sum+dp[i-1][j])%md;
}
}
for(int i=1;i<=2*n+1;i++)dp[n][i]=(dp[n][i]+dp[n][i-1])%md;
dp[n][0]=lag(dp[n],2*n+1,k);
printf("%lld",(1ll*dp[n][0]*fac[n]%md+md)%md);
return 0;
}
[TJOI2018]教科书般的亵渎
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int md=1e9+7;
int T;
long long n;
int m;
int a[55],k,f[55],pre[55],suf[55],fac[55];
map<int,int> ma;
inline int pwr(int x,int y){
int res=1;
while(y){
if(y&1)res=1ll*res*x%md;
x=1ll*x*x%md;y>>=1;
}
return res;
}
inline int lag(long long x){
if(x<=k+2)return f[x];
int res=0;pre[0]=1;suf[k+3]=1;x%=md;
for(int i=1;i<=k+2;i++)pre[i]=1ll*pre[i-1]*(x-i)%md;
for(int i=k+2;i>=1;i--)suf[i]=1ll*suf[i+1]*(x-i)%md;
for(int i=1;i<=k+2;i++){
int x=1ll*pre[i-1]*suf[i+1]%md;
int fu=((k+2-i)&1)?-1:1;
int y=1ll*fac[i-1]*fac[k+2-i]%md*fu%md;
res=(res+1ll*f[i]*x%md*pwr(y,md-2)%md)%md;
}
return (res+md)%md;
}
inline void init(){
fac[0]=1;
for(int i=1;i<=52;i++)fac[i]=1ll*fac[i-1]*i%md;
}
int main(){
scanf("%d",&T);init();
while(T--){
scanf("%lld%d",&n,&m);k=m+1;
ma.clear();
for(int i=1;i<=m;i++)scanf("%d",&a[i]),ma[a[i]]=1;
sort(a+1,a+m+1);
while(ma[n])n--,k--,m--;
for(int i=1;i<=k+2;i++)f[i]=(f[i-1]+pwr(i,k))%md;
int res=lag(n);
for(int i=1;i<=m;i++)res=(res-pwr(a[i],k))%md;
for(int i=1;i<=m;i++)res=(res+lag(n-a[i]))%md;
for(int i=1;i<=m;i++){
for(int j=i-1;j>=1;j--){
res=(res-pwr(a[i]-a[j],k))%md;
}
}
printf("%d\n",(res+md)%md);
}
return 0;
}
[NOI2019] 机器人
点击查看代码
#include<bits/stdc++.h>
using namespace std;
int n;
const int md=1e9+7;
namespace Lagrange{
int pre1[305],suf1[305],pre2[305],suf2[305],inv[305];
inline void init(){
inv[0]=inv[1]=1;
for(int i=2;i<=n;i++)inv[i]=1ll*(md-md/i)*inv[md%i]%md;
}
inline int lag(int *y,int k,int t){
if(t<=k)return y[t];
pre1[0]=suf1[0]=suf2[k+1]=1;pre2[0]=t%md;
for(int i=1;i<=k;i++)pre1[i]=1ll*pre1[i-1]*inv[i]%md;
for(int i=1;i<=k;i++)suf1[i]=1ll*suf1[i-1]*-inv[i]%md;
for(int i=1;i<=k;i++)pre2[i]=1ll*pre2[i-1]*(t-i)%md;
for(int i=k;i>=1;i--)suf2[i]=1ll*suf2[i+1]*(t-i)%md;
int res=0;
for(int i=0;i<=k;i++)res=(res+1ll*y[i]*pre1[i]%md*suf1[k-i]%md*(i?pre2[i-1]:1)%md*suf2[i+1])%md;
return (res+md)%md;
}
}
using Lagrange::lag;
int A[305],B[305],lim;
int vis[305][305],cnt,L[3005],R[3005],dp[3005][10005];
void build(int l,int r){
if(l>r||vis[l][r])return ;
vis[l][r]=++cnt;L[cnt]=l;R[cnt]=r;
if(l==r)return ;
for(int i=l;i<=r;i++){
if(abs((r-i)-(i-l))>2)continue;
build(l,i-1);build(i+1,r);
}
}
bool used[305][305];
void solve(int l,int r,int len,int v){
if(l>r||used[l][r])return ;
int id=vis[l][r];used[l][r]=1;
for(int i=1;i<=len;i++)dp[id][i]=0;
for(int i=l;i<=r;i++){
if(abs((r-i)-(i-l))>2||A[i]>v||B[i]<=v)continue;
solve(l,i-1,len,v);solve(i+1,r,len,v);
for(int j=1;j<=len;j++)dp[id][j]=(dp[id][j]+1ll*dp[vis[l][i-1]][j]*dp[vis[i+1][r]][j-1])%md;
}
for(int i=1;i<=len;i++)dp[id][i]=(dp[id][i]+dp[id][i-1])%md;
}
inline void Getval(int l,int r){
for(int i=1;i<=cnt;i++)dp[i][0]=lag(dp[i],R[i]-L[i]+1,r-l+1);
for(int l=1;l<=n;l++)for(int r=l;r<=n;r++)used[l][r]=0;
}
vector<int> hsh;
int main(){
scanf("%d",&n);
Lagrange::init();
for(int i=1;i<=n;i++)scanf("%d%d",&A[i],&B[i]),++B[i];
for(int i=1;i<=n;i++)hsh.push_back(A[i]);
for(int i=1;i<=n;i++)hsh.push_back(B[i]);
sort(hsh.begin(),hsh.end());
hsh.erase(unique(hsh.begin(),hsh.end()),hsh.end());
for(int i=1;i<=n;i++)A[i]=upper_bound(hsh.begin(),hsh.end(),A[i])-hsh.begin();
for(int i=1;i<=n;i++)B[i]=upper_bound(hsh.begin(),hsh.end(),B[i])-hsh.begin();
for(int i=0;i<=n;i++)dp[0][i]=1;
build(1,n);
for(int i=0;i+1<hsh.size();i++){
for(int l=1;l<=n;l++){
for(int r=l;r<=n;r++)if(vis[l][r])solve(l,r,min(hsh[i+1]-hsh[i],n+1),i+1);
}
Getval(hsh[i],hsh[i+1]-1);
}
printf("%d\n",dp[vis[1][n]][0]);
return 0;
}
[APIO2016]划艇
点击查看代码
#include<bits/stdc++.h>
using namespace std;
int n;
int a[505],b[505];
vector<int> hsh;
const int md=1e9+7;
int fac[505],inv[505],ifac[505];
inline void init(){
fac[0]=fac[1]=inv[0]=inv[1]=ifac[0]=ifac[1]=1;
for(int i=2;i<=n;i++)fac[i]=1ll*fac[i-1]*i%md;
for(int i=2;i<=n;i++)inv[i]=1ll*(md-md/i)*inv[md%i]%md;
for(int i=2;i<=n;i++)ifac[i]=1ll*ifac[i-1]*inv[i]%md;
}
int tmp[505],pre[505],suf[505];
inline int lag(int *f,int len,int x){
if(x<=len)return f[x];
int res=0;pre[0]=x%md;suf[len+1]=1;
for(int i=1;i<=len;i++)pre[i]=1ll*pre[i-1]*(x-i)%md;
for(int i=len;i>=0;i--)suf[i]=1ll*suf[i+1]*(x-i)%md;
for(int i=0;i<=len;i++){
int tmp=1ll*f[i]*ifac[i]%md*ifac[len-i]%md*((len-i)&1?md-1:1)%md;
if(i)tmp=1ll*tmp*pre[i-1]%md;
if(i!=len)tmp=1ll*tmp*suf[i+1]%md;
res=(res+tmp)%md;
}
return res;
}
int dp[505][505];
int main(){
scanf("%d",&n);init();
for(int i=1;i<=n;i++)scanf("%d%d",&a[i],&b[i]),b[i]++;
for(int i=1;i<=n;i++)hsh.push_back(a[i]);
for(int i=1;i<=n;i++)hsh.push_back(b[i]);
sort(hsh.begin(),hsh.end());
hsh.erase(unique(hsh.begin(),hsh.end()),hsh.end());
for(int i=1;i<=n;i++)a[i]=upper_bound(hsh.begin(),hsh.end(),a[i])-hsh.begin();
for(int i=1;i<=n;i++)b[i]=upper_bound(hsh.begin(),hsh.end(),b[i])-hsh.begin();
for(int i=0;i<=n;i++)dp[i][0]=1;
for(int t=1;t<hsh.size();t++){
for(int i=1;i<=n;i++){
if(a[i]<=t&&b[i]>t){
int sum=dp[i-1][0];
for(int j=1;j<=n;j++){
dp[i][j]=(sum+dp[i-1][j])%md;
sum=(sum+dp[i-1][j])%md;
}
}
else for(int j=1;j<=n;j++)dp[i][j]=dp[i-1][j];
}
for(int i=1;i<=n;i++){
for(int j=1;j<=i;j++)dp[i][j]=(dp[i][j]+dp[i][j-1])%md;
dp[i][0]=lag(dp[i],i,hsh[t]-hsh[t-1]);
}
}
printf("%d",(dp[n][0]+md-1)%md);
return 0;
}