模板 - 拉格朗日插值法

update:其实拉格朗日插值法的问题在于求分母的复杂度是 \(O(n^2)\) 的,要是还要求逆元则再多一个 \(logp\) 变成 \(O(n^2logp)\),但是当一个多项式要重复使用的时候,也不必求出他的各个系数,只要预处理出各项分母的逆元之后,\(O(nlogp)\) 处理分子(求出前缀积和后缀积),然后再插值,渐进复杂度和求出系数值的一样。当分母是等差数列之类的有规律的东西,他的分母可能是可以更快求出的。


参考资料:https://attack.blog.luogu.org/solution-p4781

因为高斯消元法是n立方的,有些鬼畜问题需要n平方的拉格朗日插值法。

使用逆元的n平方logn预处理,nlogn单次询问的重心拉格朗日插值法:
(假如使用double的话,理论上预处理和询问都少个log)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;


namespace Lagrange_Interpolation_Polynomial {
    const int MOD=998244353;
    const int MAXN=2000;
    int n,x[MAXN+5],y[MAXN+5];
    int invd[MAXN+5][MAXN+5];//x_i-x_j的逆元
    int kd[MAXN+5];//k-x_i的值
    int pkd,pinvd[MAXN+5];//k-x的积,invd[i]的积
    //不需要计算k-x_i的逆元,因为复杂度其实一样

    int qpow(int x,int n)
    {
        int ans=1;
        while(n)
        {
            if(n&1)
                ans=((ll)ans*x)%MOD;
            x=((ll)x*x)%MOD;
            n>>=1;
        }
        return ans;
    }

    void init_xi_xj()
    {
        for(int i=1; i<=n; i++)
        {
            for(int j=1; j<=n; j++)
            {
                if(i==j)
                    continue;
                else
                {
                    int d=(x[i]-x[j]);
                    if(d<0)
                        d+=MOD;
                    invd[i][j]=qpow(d,MOD-2);
                }
            }
        }
        for(int i=1;i<=n;i++){
            pinvd[i]=1;
            for(int j=1;j<=n;j++){
                if(i==j)
                    continue;
                pinvd[i]=(ll)pinvd[i]*invd[i][j]%MOD;
            }
        }
    }

    void init_k(int k)
    {
        pkd=1;
        for(int i=1; i<=n; i++)
        {
            kd[i]=(k-x[i]);
            if(kd[i]<0)
                kd[i]+=MOD;
            pkd=(ll)pkd*kd[i]%MOD;
        }
    }

    inline int prod(int i,int k)
    {
        ll ans=(ll)pkd*qpow(kd[i],MOD-2)%MOD;
        ans=ans*(pinvd[i])%MOD;

        return ans;
    }

    int lagrange(int k)
    {
        ll ans=0;
        init_k(k);

        for(int i=1; i<=n; i++)
        {
            ans=(ans+(ll)y[i]*(prod(i,k)))%MOD;
        }

        return ans;
    }

}

using namespace Lagrange_Interpolation_Polynomial;



int main() {
    int k;
    scanf("%d%d",&n,&k);
    for(int i=1; i<=n; i++)
    {
        scanf("%d%d",&x[i],&y[i]);
    }
    init_xi_xj();
    printf("%d\n",lagrange(k));
}

通过n个不同的点直接构造出多项式。

$f(x)=\sum\limits_{i=1}^{n} y_i \prod\limits_{i\ne j}\frac{x-x_j}{x_i-x_j} $

正确性:代入 \(x_i\) 只有 \(y_i\) 右边是1,其他分子都有0,消掉了。

那么可以n平方构造。

需要小心溢出以及卡常数。假如进行同一个多项式的多次求 \(f(k)\) 的值 ,\(x_i-x_j\) 的逆元需要多次使用,可以像下面的类似做法预处理他们的差,当然kd和pkd要重新处理。不需要用map存,这样和直接求qpow一个鬼样。
开个invd数组,保存每两个项之间的差的逆元。

注意到prod里面有一项是 \(k-x_j\) 的积,这里也可以选择直接记录这个积。毕竟模运算比较费时,卡出来的常数可能差几倍

重心拉格朗日插值法

因为是要求逆元的,所以预处理 \(O(n^2logn)\)
然后每次求新k时,预处理重心拉格朗日插值法的g, \(O(nlogn)\)

多次求同一多项式的值快了不少的,比原版的拉格朗日插值。就是要 \(O(n^2)\) 的额外空间。

先卡掉重复求差的逆元的,已经可以通过了。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

const int MOD=998244353;

int n,x[2005],y[2005];

int qpow(int x,int n){
    int ans=1;
    while(n){
        if(n&1)
            ans=((ll)ans*x)%MOD;
        x=((ll)x*x)%MOD;
        n>>=1;
    }
    return ans;
}

int invd[2005][2005];
int kd[2005];

void init_prod(int k){
    for(int i=1;i<=n;i++){
        kd[i]=(k-x[i]);
        if(kd[i]<MOD)
            kd[i]+=MOD;
        for(int j=1;j<=n;j++){
            if(i==j)
                continue;
            else{
                int d=x[i]-x[j];
                if(d<MOD)
                    d+=MOD;
                invd[i][j]=qpow(d,MOD-2);;
            }
        }
    }
}

int prod(int i,int k) {
    ll ans=1;
    for(int j=1;j<=n;j++){
        if(j==i)
            continue;
        else{
            ans=ans*(kd[j])%MOD*(invd[i][j])%MOD;
        }
    }

    return ans;
}

int lagrange(int k) {
    ll ans=0;
    init_prod(k);

    for(int i=1;i<=n;i++){
        ans=(ans+(ll)y[i]*(prod(i,k)))%MOD;
    }

    return ans;
}

int main() {
#ifdef Yinku
    freopen("Yinku.in","r",stdin);
#endif // Yinku
    int k;
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n;i++){
        scanf("%d%d",&x[i],&y[i]);
    }
    printf("%d\n",lagrange(k));
}

再卡掉n倍logn,效果不是很明显,毕竟是n平方的算法:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

const int MOD=998244353;

int n,x[2005],y[2005];

int qpow(int x,int n){
    int ans=1;
    while(n){
        if(n&1)
            ans=((ll)ans*x)%MOD;
        x=((ll)x*x)%MOD;
        n>>=1;
    }
    return ans;
}

int invd[2005][2005];
int kd[2005],pkd;

void init_prod(int k){
    pkd=1;
    for(int i=1;i<=n;i++){
        kd[i]=(k-x[i]);
        if(kd[i]<MOD)
            kd[i]+=MOD;
        pkd=(ll)pkd*kd[i]%MOD;
        for(int j=1;j<=n;j++){
            if(i==j)
                continue;
            else{
                int d=x[i]-x[j];
                if(d<MOD)
                    d+=MOD;
                invd[i][j]=qpow(d,MOD-2);;
            }
        }
    }
}

int prod(int i,int k) {
    ll ans=(ll)pkd*qpow(kd[i],MOD-2)%MOD;
    for(int j=1;j<=n;j++){
        if(j==i)
            continue;
        else{
            ans=ans*(invd[i][j])%MOD;
        }
    }

    return ans;
}

int lagrange(int k) {
    ll ans=0;
    init_prod(k);

    for(int i=1;i<=n;i++){
        ans=(ans+(ll)y[i]*(prod(i,k)))%MOD;
    }

    return ans;
}

int main() {
#ifdef Yinku
    freopen("Yinku.in","r",stdin);
#endif // Yinku
    int k;
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n;i++){
        scanf("%d%d",&x[i],&y[i]);
    }
    printf("%d\n",lagrange(k));
}

也可以通过之前用高斯消元法做的https://codeforces.com/contest/1155/problem/E,但是高斯消元法在处理多次询问的时候貌似复杂度好一些?(其实是我写挂了,询问变成n平方log的了,已订正)

实测比高斯消元快一些,估计是高斯消元写得不好。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;


const int MOD=1000003;
int inv[MOD+5];

namespace Lagrange_Interpolation_Polynomial {
    const int MAXN=10;
    int n,x[MAXN+5],y[MAXN+5];
    int invd[MAXN+5][MAXN+5];//x_i-x_j的逆元
    int kd[MAXN+5];//k-x_i的值
    int pkd,pinvd[MAXN+5];//k-x的积,invd[i]的积
    //不需要计算k-x_i的逆元,因为复杂度其实一样

    int qpow(int x,int n)
    {
        int ans=1;
        while(n)
        {
            if(n&1)
                ans=((ll)ans*x)%MOD;
            x=((ll)x*x)%MOD;
            n>>=1;
        }
        return ans;
    }

    void init_xi_xj()
    {
        for(int i=1; i<=n; i++)
        {
            for(int j=1; j<=n; j++)
            {
                if(i==j)
                    continue;
                else
                {
                    int d=(x[i]-x[j]);
                    if(d<0)
                        d+=MOD;
                    invd[i][j]=inv[d];
                }
            }
        }
        for(int i=1;i<=n;i++){
            pinvd[i]=1;
            for(int j=1;j<=n;j++){
                if(i==j)
                    continue;
                pinvd[i]=(ll)pinvd[i]*invd[i][j]%MOD;
            }
        }
    }

    void init_k(int k)
    {
        pkd=1;
        for(int i=1; i<=n; i++)
        {
            kd[i]=(k-x[i]);
            if(kd[i]<0)
                kd[i]+=MOD;
            pkd=(ll)pkd*kd[i]%MOD;
        }
    }

    inline int prod(int i,int k)
    {
        ll ans=(ll)pkd*inv[kd[i]]%MOD;
        ans=ans*(pinvd[i])%MOD;

        return ans;
    }

    int lagrange(int k)
    {
        ll ans=0;
        init_k(k);

        for(int i=1; i<=n; i++)
        {
            ans=(ans+(ll)y[i]*(prod(i,k)))%MOD;
        }

        return ans;
    }

}

using namespace Lagrange_Interpolation_Polynomial;


ll jury(ll x) {
    ll a[11]= {10002,0,1};
    ll ret=0;
    ll tx=1;
    for(int i=0; i<=10; i++) {
        ret=(ret+tx*a[i]%MOD)%MOD;
        tx=tx*x%MOD;
    }
    //cout<<"jury: "<<ret<<endl;
    return ret;
}


int query(int x) {
    printf("? %d\n",x);
    fflush(stdout);
    scanf("%d",&x);
    //x=jury(x);
    return x;
}

void answer(int x) {
    printf("! %d\n",x);
}


int main() {
    n=11;

    inv[1]=1;
    for(int i=2;i<MOD;i++){
        inv[i]=(ll)inv[MOD%i]*(MOD-MOD/i)%MOD;
    }

    for(int i=0; i<=10; i++) {
        x[i+1]=i;
        int res=query(i);
        if(res==0) {
            answer(i);
            exit(0);
        }
        y[i+1]=res;
    }

    init_xi_xj();
    for(int i=11;i<MOD;i++){
        int res=lagrange(i);
        //cout<<"res="<<res<<endl;
        //cout<<"jury="<<jury(i)<<endl;
        if(res==0){
            answer(i);
            exit(0);
        }
    }
    answer(-1);
}

一般我们代入的点都是连续自然数,当 \(x_i\) 的取值是连续自然数时,插值法因为分母变得容易处理而变成n复杂度。

以从0开始取举例。

假设当前n为4,从1开始取。其实分子无论怎么取值都是可以用前缀积和后缀积预处理出来的,然后两段乘在一起。关键是分母,当分母是连续自然数的时候不必两两枚举。

\(y_1\frac{1*(x-2)(x-3)(x-4)}{(1-2)(1-3)(1-4)}+y_2\frac{(x-1)*(x-3)(x-4)}{(2-1)(2-3)(2-4)}+y_3\frac{(x-1)(x-2)*(x-4)}{(3-1)(3-2)(3-4)}+y_4\frac{(x-1)(x-2)(x-3)*1}{(4-1)(4-2)(4-3)}\)

分母是:\(1*(-1)(-2)(-3),(1)*(-1)*(-2),(2)(1)*(-1),(3)(2)(1)\)

也就是前面阶乘,后面是负数的阶乘,负数的阶乘也很好想,比如这里n=4,那么第一项有3个负数,阶乘为负数,第二项为正数,第三项为负数……也就是n-i为奇数的项是负数。

可不断插点的重心拉格朗日插值法。

原式
$f(x)=\sum\limits_{i=1}^{n} y_i \prod\limits_{i\ne j}\frac{x-x_j}{x_i-x_j} $

把分子补上缺的项,提出来,记 \(g=\prod\limits_{i=1}^{n} x-x_i\)
$f(x)=g\sum\limits_{i=1}^{n} \frac{1}{x-x_i} \prod\limits_{i\ne j}\frac{y_i}{x_i-x_j} $

再记 \(h_i= \prod\limits_{i\ne j}\frac{y_i}{x_i-x_j}\)
$f(x)=g\sum\limits_{i=1}^{n} \frac{h_i}{x-x_i} $

那么每次多加一个点的时候,更新 \(g\) 以及每一个 \(h_i\) 就可以了。


经典例题:求 $\sum\limits_{i=1}^{n} i^k $ n很大但k只有几千万(需要连续前缀优化)
https://www.luogu.org/problemnew/show/P4593
推出计算式之后,就是计算上面那个东西。注意步骤,首先要确定多项式的系数,一个k次多项式的和一般都是k+1次的,这里是m+1次多项式的和,也就是m+2次多项式,需要m+3个点。

注意用拉格朗日插值时,判断k与x_i不重复再插,不然返回y_i就可以了。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

const int MOD=1e9+7;

namespace Lagrange_Interpolation_Polynomial {
    const int MAXN=80;
    int n,x[MAXN+5],y[MAXN+5];
    int invd[MAXN+5][MAXN+5];//x_i-x_j的逆元
    int kd[MAXN+5];//k-x_i的值
    int pkd,pinvd[MAXN+5];//k-x的积,invd[i]的积
    //不需要计算k-x_i的逆元,因为复杂度其实一样

    int qpow(int x,int n)
    {
        int ans=1;
        while(n)
        {
            if(n&1)
                ans=((ll)ans*x)%MOD;
            x=((ll)x*x)%MOD;
            n>>=1;
        }
        return ans;
    }

    void init_xi_xj()
    {
        for(int i=1; i<=n; i++)
        {
            //printf("x[%d]=%d y[%d]=%d\n",i,x[i],i,y[i]);
            for(int j=1; j<=n; j++)
            {
                if(i==j)
                    continue;
                else
                {
                    int d=(x[i]-x[j]);
                    if(d<0)
                        d+=MOD;
                    invd[i][j]=qpow(d,MOD-2);
                }
            }
        }
        for(int i=1;i<=n;i++){
            pinvd[i]=1;
            for(int j=1;j<=n;j++){
                if(i==j)
                    continue;
                pinvd[i]=(ll)pinvd[i]*invd[i][j]%MOD;
            }
            //cout<<"pinvd["<<i<<"]="<<pinvd[i]<<endl;
        }
    }

    void init_k(int k)
    {
        pkd=1;
        for(int i=1; i<=n; i++)
        {
            kd[i]=(k-x[i]);
            if(kd[i]<0)
                kd[i]+=MOD;
            pkd=(ll)pkd*kd[i]%MOD;
        }
        //cout<<"pkd="<<pkd<<endl;
    }

    inline int prod(int i,int k)
    {
        ll ans=(ll)pkd*qpow(kd[i],MOD-2)%MOD;
        ans=ans*(pinvd[i])%MOD;

        //cout<<"in prod ans="<<ans<<endl;
        return ans;
    }

    int lagrange(int k)
    {
        for(int i=1;i<=n;i++){
            if(x[i]==k)
                return y[i];
        }
        ll ans=0;
        //cout<<"k="<<n<<endl;
        init_k(k);

        for(int i=1; i<=n; i++)
        {
            ans=(ans+(ll)y[i]*(prod(i,k)))%MOD;
            //cout<<"in lag ans="<<ans<<endl;
        }
        //cout<<"(1~n)i^k="<<ans<<endl;
        return ans;
    }

}

using namespace Lagrange_Interpolation_Polynomial;



ll init(int pn,int k){
    //拉格朗日插值的n需要的是多项式的点的数目
    n=pn;

    for(int i=1;i<=n;i++){
        x[i]=i;
        y[i]=0;
        for(int j=1;j<=i;j++){
            y[i]=((ll)y[i]+qpow(j,k))%MOD;
        }
    }
    //cerr<<"n="<<n<<endl;;
    init_xi_xj();
}

ll calc(ll r){
    //cerr<<"lag"<<endl;
    ll t=lagrange(r);
    //cout<<"calc"<<" r="<<t<<endl;
    return t;
}

vector<int> vm;


int main() {
    int t;
    scanf("%d",&t);
    while(t--){
        ll n;
        int m;
        scanf("%lld%d",&n,&m);

        int k=m+1;
        //m+1次方的和,m+2次多项式,需要m+3个点
        vm.clear();
        for(int i=0;i<m;i++){
            int tt;
            scanf("%d",&tt);
            vm.push_back(tt);
        }
        sort(vm.begin(),vm.end());
        n%=MOD;
        //m+1次方的和

        //m+2次多项式,需要m+3个点
        init(k+2,k);
        ll ans=0;
        while(n){
            ans=(ans+calc(n))%MOD;
            for(auto &mi:vm){
                ans=(ans-qpow(mi,k))%MOD;
                if(ans<0)
                    ans+=MOD;
            }
            //cout<<"ans=!!!   "<<ans<<endl;
            if(vm.size()==0)
                break;
            vector<int> tv;
            for(int i=1;i<vm.size();i++){
                tv.push_back(vm[i]-vm[0]);
            }

            n-=vm[0];
            vm=tv;
            //vm.erase(ve.begin());
        }
        cout<<ans%MOD<<endl;
    }
}
posted @ 2019-04-25 19:40  韵意  阅读(730)  评论(0编辑  收藏  举报