拉格朗日插值

全文基本照抄 OI-wiki,这个博客存在的意义只是自己查起来方便。

Introduction

众所周指,\(n\) 个点值唯一确定一个 \(n-1\) 次多项式。假设 \(n\) 个点值分别为 \(f(x_i)=y_i\;(1\le i\le n)\),那么一个对于任意一个 \(k\)\(f(k)\) 可以通过如下公式得出:

\[f(k)=\sum_{i=1}^ny_i\prod_{j\neq i}\frac{k-x_j}{x_i-x_j} \]

证明见 OI Wiki

不难发现,通过这个公式暴力拆解还可以得到多项式的系数表示。

\(s_i=y_i\prod_{j\not=i}\frac{1}{x_i-x_j}\),这是很容易在 \(\mathcal O(n^2)\) 以内算出的,此时 \(f(k)=\sum_{i} y_is_i\prod_{j\not=i} (k-x_j)\)\(\mathcal O(n^2)\) 预处理 \(\prod_j (k-x_j)\) 的系数后,对于每个 \(i\) 可以在线性多项式除法得到 \(\prod_{j\not=i} (k-x_j)\) 的系数。这样总的复杂度是 \(\mathcal O(n^2)\),比直接高斯消元优很多。

推导

假设 \(n\) 个点分别是 \(P_i(x_i,y_i)\;(i\in[1,n])\),第 \(i\) 个点在 \(x\) 轴上的投影为 \(P'(x_i,0)\)。考虑 \(n\) 个函数 \(f_i(x)\),第 \(i\) 个函数过 \(\begin{cases}P_j, &j=i\\P'_j, &j\neq i\end{cases}\) 这些点,那么 \(f(x)=\sum_i f_i(x)\) 就是一个满足条件的函数。

根据代数基本定理,设 \(f_i(x)=a\prod_i (x-x_i)\),带入 \(P_i(x_i,y_i)\) 得到 \(a=\dfrac{y_i}{\prod_{j\neq i} (x_i-x_j)}\),则:

\[f_i(x)=y_i\prod_{j\neq i}\frac{x-x_j}{x_i-x_j}\\f(x)=\sum_{i} f_i(x)=\sum_{i}y_i\prod_{j\neq i}\frac{x-x_j}{x_i-x_j} \]

横坐标是连续整数的拉格朗日插值

如果已知点的横坐标是连续整数,我们可以做到 \(O(n)\) 插值。

设要求 \(n\) 次多项式为 \(f(x)\),我们已知 \(f(1),\cdots,f(n+1)\)\(1\le i\le n+1\)),考虑代入上面的插值公式:

\[\begin{aligned} f(x)&=\sum_{i=1}^{n+1}y_i\prod_{j\ne i}\frac{x-x_j}{x_i-x_j}\\ &=\sum_{i=1}^{n+1}y_i\prod_{j\ne i}\frac{x-j}{i-j} \end{aligned} \]

后面的累乘可以分子分母分别考虑,不难得到分子为:

\[\dfrac{\prod_{j=1}^{n+1}(x-j)}{x-i} \]

分母的 \(i-j\) 累乘可以拆成两段阶乘来算:

\[(-1)^{n+1-i}\cdot(i-1)!\cdot(n+1-i)! \]

于是横坐标为 \(1,\cdots,n+1\) 的插值公式:

\[f(x)=\sum_{i=1}^{n+1}y_i\cdot\frac{\prod_{j=1}^{n+1}(x-j)}{(x-i)\cdot(-1)^{n+1-i}\cdot(i-1)!\cdot(n+1-i)!} \]

预处理 \((x-i)\) 前后缀积、阶乘阶乘逆,然后代入这个式子,复杂度为 \(O(n)\)

【21 ZR联赛集训 day8】连通

\(\mathrm{dp}(u,k)\) 表示以 \(u\) 为根的子树内,包含 \(u\)、大小为 \(k\) 的所有连通块答案之和。将其写成生成函数的形式:\(F_u(z)=\sum_i \mathrm{dp}(u,i)z^i\),转移就是 \(F_u(z)=a_uz\prod_{v\in son_u}(F_v(z)+1)\)。代 \(n+1\) 个点值 \(0,1,\cdots,n\) 进去然后简单 dp 可以所有的 \(F_{x}(i)\),最后再拉格朗日插值把每个 \(F\) 的系数插出来即可。加一些 trival 的前缀和之类的可以做到时间复杂度 \(\mathcal O(n^2)\)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef vector<int> vi;
#define mp make_pair
#define pb push_back
#define fi first
#define se second
inline int read()
{
	int x=0,f=1;char c=getchar();
	while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
	return x*f;
}
void write(int n)
{
	if(n<0){putchar('-');n=-n;}
	if(n>9)write(n/10);
	putchar(n%10^48);
}
const int mod=998244353;
int qpow(int a,int n)
{
	int ans=1;
	while(n)
	{
		if(n&1)ans=1ll*ans*a%mod;
		a=1ll*a*a%mod;
		n>>=1;
	}
	return ans;
}
const int N=6e3+10,M=1e4+10;
int head[N],ver[M],nxt[M],tot=0,inv[N];
void add(int x,int y)
{
	ver[++tot]=y;
	nxt[tot]=head[x];
	head[x]=tot;
}
int a[N],f[N][N],f1[N][N],pre[N],suf[N],inv1[N];
void dfs(int x,int u,int fa)
{
	int ans=1ll*a[u]*x%mod;
	for(int i=head[u];i;i=nxt[i])
	{
		int v=ver[i];if(v==fa)continue;
		dfs(x,v,u);
		ans=1ll*ans*(f1[x][v]+1)%mod;
	}
	f1[x][u]=ans;
}
void dfs1(int x,int u,int fa)
{
	if(u==1)f[x][u]=f1[x][u];
	else f[x][u]=((f[x][fa]-1ll*f[x][fa]*inv[u]%mod+mod)%mod+f1[x][u])%mod;
	for(int i=head[u];i;i=nxt[i])
	{
		int v=ver[i];if(v==fa)continue;
		dfs1(x,v,u);
	}
}
int dp[N],n,k,gg[N],G[N][N];
void mul(int x)
{
	for(int i=n+1;i>=0;i--)
	{
		if(i)gg[i]=(gg[i-1]+1ll*x*gg[i])%mod;
		else gg[i]=1ll*x*gg[i]%mod;
	}
}
void div(int x)//prod_i(z-i)  /   (z-i)
{
	for(int i=0;i<=k;i++)
	{
		if(!i)G[x][i]=(mod-gg[i])*1ll*inv1[x]%mod;
		else G[x][i]=(G[x][i-1]-gg[i]+mod)%mod*1ll*inv1[x]%mod;
	}
	for(int i=1;i<=k;i++)G[x][i]+=G[x][i-1],G[x][i]%=mod;
}
int b[N],s[N];
void calc(int u)
{
	for(int i=0;i<=n;i++)b[i]=f[i][u]*1ll*s[i]%mod;
	for(int i=0;i<=n;i++)dp[u]+=1ll*b[i]*G[i][k]%mod,dp[u]%=mod;
}
int main()
{
	n=read(),k=read();
	for(int i=1;i<=n;i++)a[i]=read(); 
	for(int i=1;i<=n;i++)inv1[i]=qpow(i,mod-2);
	for(int i=1;i<n;i++){int u=read(),v=read();add(u,v),add(v,u);}
	gg[0]=1;for(int i=0;i<=n;i++)
	{
		dfs(i,1,-1);
		pre[0]=1,suf[n+1]=1;
		for(int j=1;j<=n;j++)pre[j]=1ll*pre[j-1]*(f1[i][j]+1)%mod;
		for(int j=n;j;j--)suf[j]=1ll*suf[j+1]*(f1[i][j]+1)%mod;
		int tmp=qpow(pre[n],mod-2);
		for(int j=1;j<=n;j++)inv[j]=1ll*tmp*pre[j-1]%mod*suf[j+1]%mod;
		dfs1(i,1,-1),mul(mod-i);
	}
	for(int i=0;i<=n;i++)div(i);
	for(int i=0;i<=n;i++)
	{
		s[i]=1;
		for(int j=0;j<=n;j++)
		{
			if(i==j)continue;
			s[i]=(i-j+mod)%mod*1ll*s[i]%mod; 
		}
		s[i]=qpow(s[i],mod-2);
	}
	for(int i=1;i<=n;i++)calc(i);
	for(int i=1;i<=n;i++)write(dp[i]),putchar(' ');
}

References

https://www.luogu.com.cn/problem/solution/P4781

https://oi-wiki.org/math/poly/lagrange/

posted @ 2021-10-26 15:33  zzt1208  阅读(123)  评论(2编辑  收藏  举报