拉格朗日插值

拉格朗日插值主要用于求解如下问题:

给出\(n\)个二维点\((x_i,y_i)\),找出过所有点的多项式\(f(x)\)\(x\)处的取值(通常\(x\)较大)

考虑对于每个点构造函数\(f_i(x)\)使得\(f_i(x_i)=y_i\),且\(\forall x_j(j\neq i) f_i(x_j)=0\)

如何满足后式?\(f_i(x)=g(x)\cdot\prod_{j\neq i}(x-x_j)\)

然后我们需要凑出前式,使\(g(x)=y_i\prod_{j\neq x}\frac{1}{x_i-x_j}\)即可。

综上所述:

\[f_i(x)=y_i\prod_{j\neq i}\frac{x-x_j}{x_i-x_j} \]

然后:

\[f(x)=\sum_{i=1}^{n}y_i\prod_{j\neq i}\frac{x-x_j}{x_i-x_j} \]

代入计算即可,时间复杂度\(O(n^2)\)

然后你就可以AC这道模板题

代码如下:

int Lagrange(int *x, int *y, int n, int k)
{
    int res=0;
    for (int i=1; i<=n; i++)
    {
        int s1=1, s2=1;
        for (int j=1; j<=n; j++) if (i^j)
        {
            s1=1ll*s1*(k-x[j])%P;
            s2=1ll*s2*(x[i]-x[j])%P;
        }
        res=(res+1ll*(1ll*s1*Pow(s2, P-2)%P)*y[i])%P;
    }
    return (res%P+P)%P;
}

如果\(x_i\)取值是一段连续的数时,该算法是否有优化空间?

答案是肯定的,考虑\(x_i\in[1,n]\),那么

\[f(x)=\sum_{i=1}^{n}y_i\prod_{i\neq j}\frac{x-j}{i-j} \]

如何快速计算\(y_i\prod_{i\neq j}\frac{x-j}{i-j}\)?维护前缀与后缀即可

\[pre_i=\prod_{j=1}^{i}(x-j) \]

\[suf_i=\prod_{j=i}^{n}(x-j) \]

\[f(x)=\sum_{i=1}^{n}y_i\frac{pre_{i-1}\cdot suf_{i+1}}{fac_i\cdot fac_{n-i}} \]

预处理阶乘逆元,前缀积,后缀积即可,时间复杂度\(O(n)\)

代码如下:

int Lagrange(int *y,int n,int k)
{
    int ans=0, fac=1; pre[0]=suf[n+1]=1; ifac[n]=Pow(fac, P-2);
    for (int i=1; i<=n; i++) pre[i]=1ll*pre[i-1]*(k-i)%P;
    for (int i=n; i; i--) suf[i]=1ll*suf[i+1]*(k-i)%P;
    for (int i=1; i<=n; i++) fac=1ll*fac*i%P;
    for (int i=n-1; i; i--) ifac[i]=1ll*ifac[i+1]*(i+1)%P;
    for (int i=1; i<=n; i++)
    {
        int s1=1ll*pre[i-1]*suf[i+1]%P;
        int s2=1ll*ifac[i-1]*ifac[n-i]%P;
        ans=(ans+1ll*((n-i)&1?-1:1)*s1*s2%P*y[i])%P;
    }
    return (ans%P+P)%P;
}

看一道例题

CF622F The Sum of the k-th Powers

题目传送门

Description

\(\sum_{i=1}^{n}i^m\),对\(10^9+7\)取模,\(n\leq10^9,m\leq10^6\)

Solution

这是一个题面自带题解的题。

题面给出了以下恒等式:

\[\sum_{i=1}^{n}=\frac{n(n+1)}{2} \]

\[\sum_{i=1}^{n}i^2=\frac{n(n+1)(2n+1)}{6} \]

\[\sum_{i-1}^{n}i^3=(\frac{n(n+1)}{2})^2 \]

经验告诉我们CF给的提示都是很有用的。

这三个等式暗示了什么?

\(\sum_{i=1}^{n}i^m\)是一个(m+1)次多项式!

那么就可以插值了,把\(k\in[1,m ]\)的值代入计算,然后用\(O(n)\)\(O(nlogn)\)的插值即可。

下面是复杂度为\(O(nlogn)\)的实现

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=1000000007, N=1000005;
int pre[N], suf[N], fac[N];

inline int read()
{
    int x=0,f=1;char ch=getchar();
    for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
    for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
    return x*f;
}

int Pow(int x, int t)
{
    int res=1;
    while (t)
    {
        if (t&1) (res*=x)%=mod;
        (x*=x)%=mod; t>>=1;
    }
    return res;
}

signed main()
{
    int n=read(), k=read(), ans=0, y=0; pre[0]=suf[k+3]=fac[0]=1;
    for (int i=1; i<=k+2; i++) pre[i]=pre[i-1]*(n-i)%mod;
    for (int i=k+2; i; i--) suf[i]=suf[i+1]*(n-i)%mod;
    for (int i=1; i<=k+2; i++) fac[i]=fac[i-1]*i%mod;
    for (int i=1; i<=k+2; i++)
    {
        (y+=Pow(i, k))%=mod;
        int s1=((k-i)&1?-1:1)*pre[i-1]*suf[i+1]%mod;
        int s2=fac[i-1]*fac[k+2-i]%mod;
        (ans+=s1*Pow(s2, mod-2)%mod*y%mod)%=mod;
    }
    printf("%d\n", (ans+mod)%mod);
    return 0;
}

再看一个难度大一些的题

CF995F Cowmpany Cowmpensation

题目传送门

Description

给定一棵\(N(N\leq3000)​\)个点的有根树,给每个节点赋一个\(\leq D(D\leq10^9)​\)的值,并且保证儿子节点的值\(\leq​\)父亲节点的值,求方案数。

Solution

先写一个\(O(nD)\)\(DP\)方程。

\(dp_{i,j}\)表示\(i\)号节点取值为\(j\)时子树内的总方案数

\[dp_{i,j}=\prod_{p\in son_i}\sum_{k=1}^{j}dp_{p,k} \]

考虑优化\(DP\),因为\(D\)的值过大,所以我们要让复杂度与\(D\)无关。

于是我们想到了拉格朗日插值。

我们先看出来,\(dp_{i,j}​\)是一个关于\(j​\)的多项式。

然后它的次数还要很小,其实它的次数不超过\(n​\)

考虑证明?

我不会

\(\forall u\in leaf,dp_{u,D}=D=f^1(D)​\)

\(\forall u\notin leaf,dp_{u, D}=\prod_{son}f^x(D)\)

感性理解一下,叶子结点为关于\(D\)的一次多项式,非叶子节点为儿子节点的积,对于指数就是和,于是根节点的次数就为叶子结点数我也不知道对不对

给出\(O(n^2)\)预处理\(dp_{n,n}\)以内的\(DP\)值和\(O(n^2)\)插值的代码。

#include<bits/stdc++.h>
#define rep(i, a, b) for (register int i=(a); i<=(b); ++i)
#define per(i, a, b) for (register int i=(a); i>=(b); --i)
using namespace std;
const int P=1000000007, N=3005;
int x[N], y[N], f[N][N], n, D;
vector<int> G[N];

inline int read()
{
    int x=0,f=1;char ch=getchar();
    for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
    for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
    return x*f;
}

int Pow(int x, int t)
{
    int res=1;
    for (; t; t>>=1, x=1ll*x*x%P) if (t&1) res=1ll*res*x%P;
    return res;
}

int Lagrange(int *x, int *y, int n, int X)
{
    int res=0;
    for (int i=0; i<=n; i++)
    {
        int s1=1, s2=1;
        for (int j=0; j<=n; j++) if (i^j)
        {
            s1=1ll*s1*(X-x[j])%P;
            s2=1ll*s2*(x[i]-x[j])%P;
        }
        res=(res+1ll*(1ll*s1*Pow(s2, P-2)%P)*y[i])%P;
    }
    return (res%P+P)%P;
}

void dfs(int u)
{
    rep(i, 1, n) f[u][i]=1;
    for (int v: G[u])
        {dfs(v); rep(j, 1, n) f[u][j]=1ll*f[u][j]*f[v][j]%P;}
    rep(i, 1, n) f[u][i]=(f[u][i]+f[u][i-1])%P;
}

int main()
{
    n=read(); D=read();
    rep(i, 2, n) G[read()].push_back(i);
    dfs(1);
    rep(i, 0, n) x[i]=i, y[i]=f[1][i];
    printf("%d\n", Lagrange(x, y, n, D));
    return 0;
}

最后一道毒瘤题

[国家集训队]calc

题目传送门

Description

构造序列\(a_1,\dots,a_n\),满足\(a_1,\dots,a_n\in[1,A]\),且\(a_1,\dots,a_n\)互不相等。定义合法序列的值为\(\prod_{i=1}^{n}a_i\),求不同合法序列的值的和。\(n\leq500,A\leq10^9\)

Solution

神仙题\(Orz\)

可以看看我借鉴的这篇题解

依旧先写\(O(nA)\)\(DP\)方程。

\(f_{i,j}\)表示前\(i\)个数取\([1,j]\)的不同合法序列的值的和,仅考虑递增序列。

\[f_{i,j}=f_{i-1,j-1}\cdot j+f_{i, j-1} \]

\[Ans=f_{n, A}\cdot n! \]

我们又要猜结论了!\(QAQ\)

\(f_{n,i}\)为关于\(i\)\(g(n)\)次多项式。

引理:\(f^n(x)-f^n(x-1)\)\(f^{n-1}(x)\)

\(n\)次多项式的差分为\(n-1\)次多项式。

·证明我不会(我想大家都会

考虑\(DP\)方程中也有差分的形式

\[f_{n,i}-f_{n,i-1}=f_{n-1,i-1}\cdot i \]

于是

\[g(n)-1=g(n-1)+1 \]

又有

\[g(0)=0 \]

所以

\[g(n)=2n \]

所以\(f_(n,i)\)是关于\(i\)\(2n\)次多项式

求出\(f_{n,1}\)\(f_{n,2n+1}\)即可

\(O(n^2)\)\(DP\)\(O(n^2)\)的插值代码:

#include<bits/stdc++.h>
using namespace std;
const int N=5005;
int x[N], y[N], f[N][N<<2];

inline int read()
{
    int x=0,f=1;char ch=getchar();
    for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
    for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
    return x*f;
}

int Pow(int x, int t, int P)
{
    int res=1;
    for (; t; t>>=1, x=1ll*x*x%P) if (t&1) res=1ll*res*x%P;
    return res;
}

int Lagrange(int *x, int *y, int n, int k, int P)
{
    int res=0;
    for (int i=1; i<=n; i++)
    {
        int s1=1, s2=1;
        for (int j=1; j<=n; j++) if (i^j)
        {
            s1=1ll*s1*(k-x[j])%P;
            s2=1ll*s2*(x[i]-x[j])%P;
        }
        res=(res+1ll*(1ll*s1*Pow(s2, P-2, P)%P)*y[i])%P;
    }
    return (res%P+P)%P;
}

int main()
{
    int A=read(), n=read(), m=2*n+1, P=read(), fac=1, res=0;
    for (int i=1; i<=n; i++) fac=1ll*fac*i%P;
    for (int i=0; i<=m; i++) f[0][i]=1;
    for (int i=1; i<=n; i++)
        for (int j=1; j<=m; j++)
            f[i][j]=(1ll*f[i-1][j-1]*j+f[i][j-1])%P;
    for (int i=1; i<=m; i++) x[i]=i;
    for (int i=1; i<=m; i++) y[i]=f[n][i];
    printf("%d\n", 1ll*fac*Lagrange(x, y, m, A, P)%P);
    return 0;
}

写在最后

拉格朗日插值是一种通过点值转换为插值的算法,当然还有\(O(nlog^2n)\)的快速插值,但仅出现在少数毒瘤多项式中。

拉格朗日插值的一个经典应用是优化\(O(nm)\)\(DP\),当\(DP\)值是一个关于\(m\)\(n\)的级别次多项式,且\(m\)极大时。它的数据范围也很有辨识度对吧

posted @ 2019-04-04 00:39  OIerC  阅读(697)  评论(0编辑  收藏  举报