拉格朗日插值
全文基本照抄 OI-wiki,这个博客存在的意义只是自己查起来方便。
Introduction
众所周指,\(n\) 个点值唯一确定一个 \(n-1\) 次多项式。假设 \(n\) 个点值分别为 \(f(x_i)=y_i\;(1\le i\le n)\),那么一个对于任意一个 \(k\),\(f(k)\) 可以通过如下公式得出:
证明见 OI Wiki。
不难发现,通过这个公式暴力拆解还可以得到多项式的系数表示。
令 \(s_i=y_i\prod_{j\not=i}\frac{1}{x_i-x_j}\),这是很容易在 \(\mathcal O(n^2)\) 以内算出的,此时 \(f(k)=\sum_{i} y_is_i\prod_{j\not=i} (k-x_j)\)。\(\mathcal O(n^2)\) 预处理 \(\prod_j (k-x_j)\) 的系数后,对于每个 \(i\) 可以在线性多项式除法得到 \(\prod_{j\not=i} (k-x_j)\) 的系数。这样总的复杂度是 \(\mathcal O(n^2)\),比直接高斯消元优很多。
推导
假设 \(n\) 个点分别是 \(P_i(x_i,y_i)\;(i\in[1,n])\),第 \(i\) 个点在 \(x\) 轴上的投影为 \(P'(x_i,0)\)。考虑 \(n\) 个函数 \(f_i(x)\),第 \(i\) 个函数过 \(\begin{cases}P_j, &j=i\\P'_j, &j\neq i\end{cases}\) 这些点,那么 \(f(x)=\sum_i f_i(x)\) 就是一个满足条件的函数。
根据代数基本定理,设 \(f_i(x)=a\prod_i (x-x_i)\),带入 \(P_i(x_i,y_i)\) 得到 \(a=\dfrac{y_i}{\prod_{j\neq i} (x_i-x_j)}\),则:
横坐标是连续整数的拉格朗日插值
如果已知点的横坐标是连续整数,我们可以做到 \(O(n)\) 插值。
设要求 \(n\) 次多项式为 \(f(x)\),我们已知 \(f(1),\cdots,f(n+1)\)(\(1\le i\le n+1\)),考虑代入上面的插值公式:
后面的累乘可以分子分母分别考虑,不难得到分子为:
分母的 \(i-j\) 累乘可以拆成两段阶乘来算:
于是横坐标为 \(1,\cdots,n+1\) 的插值公式:
预处理 \((x-i)\) 前后缀积、阶乘阶乘逆,然后代入这个式子,复杂度为 \(O(n)\)。
【21 ZR联赛集训 day8】连通
设 \(\mathrm{dp}(u,k)\) 表示以 \(u\) 为根的子树内,包含 \(u\)、大小为 \(k\) 的所有连通块答案之和。将其写成生成函数的形式:\(F_u(z)=\sum_i \mathrm{dp}(u,i)z^i\),转移就是 \(F_u(z)=a_uz\prod_{v\in son_u}(F_v(z)+1)\)。代 \(n+1\) 个点值 \(0,1,\cdots,n\) 进去然后简单 dp 可以所有的 \(F_{x}(i)\),最后再拉格朗日插值把每个 \(F\) 的系数插出来即可。加一些 trival 的前缀和之类的可以做到时间复杂度 \(\mathcal O(n^2)\)。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef vector<int> vi;
#define mp make_pair
#define pb push_back
#define fi first
#define se second
inline int read()
{
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
return x*f;
}
void write(int n)
{
if(n<0){putchar('-');n=-n;}
if(n>9)write(n/10);
putchar(n%10^48);
}
const int mod=998244353;
int qpow(int a,int n)
{
int ans=1;
while(n)
{
if(n&1)ans=1ll*ans*a%mod;
a=1ll*a*a%mod;
n>>=1;
}
return ans;
}
const int N=6e3+10,M=1e4+10;
int head[N],ver[M],nxt[M],tot=0,inv[N];
void add(int x,int y)
{
ver[++tot]=y;
nxt[tot]=head[x];
head[x]=tot;
}
int a[N],f[N][N],f1[N][N],pre[N],suf[N],inv1[N];
void dfs(int x,int u,int fa)
{
int ans=1ll*a[u]*x%mod;
for(int i=head[u];i;i=nxt[i])
{
int v=ver[i];if(v==fa)continue;
dfs(x,v,u);
ans=1ll*ans*(f1[x][v]+1)%mod;
}
f1[x][u]=ans;
}
void dfs1(int x,int u,int fa)
{
if(u==1)f[x][u]=f1[x][u];
else f[x][u]=((f[x][fa]-1ll*f[x][fa]*inv[u]%mod+mod)%mod+f1[x][u])%mod;
for(int i=head[u];i;i=nxt[i])
{
int v=ver[i];if(v==fa)continue;
dfs1(x,v,u);
}
}
int dp[N],n,k,gg[N],G[N][N];
void mul(int x)
{
for(int i=n+1;i>=0;i--)
{
if(i)gg[i]=(gg[i-1]+1ll*x*gg[i])%mod;
else gg[i]=1ll*x*gg[i]%mod;
}
}
void div(int x)//prod_i(z-i) / (z-i)
{
for(int i=0;i<=k;i++)
{
if(!i)G[x][i]=(mod-gg[i])*1ll*inv1[x]%mod;
else G[x][i]=(G[x][i-1]-gg[i]+mod)%mod*1ll*inv1[x]%mod;
}
for(int i=1;i<=k;i++)G[x][i]+=G[x][i-1],G[x][i]%=mod;
}
int b[N],s[N];
void calc(int u)
{
for(int i=0;i<=n;i++)b[i]=f[i][u]*1ll*s[i]%mod;
for(int i=0;i<=n;i++)dp[u]+=1ll*b[i]*G[i][k]%mod,dp[u]%=mod;
}
int main()
{
n=read(),k=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<=n;i++)inv1[i]=qpow(i,mod-2);
for(int i=1;i<n;i++){int u=read(),v=read();add(u,v),add(v,u);}
gg[0]=1;for(int i=0;i<=n;i++)
{
dfs(i,1,-1);
pre[0]=1,suf[n+1]=1;
for(int j=1;j<=n;j++)pre[j]=1ll*pre[j-1]*(f1[i][j]+1)%mod;
for(int j=n;j;j--)suf[j]=1ll*suf[j+1]*(f1[i][j]+1)%mod;
int tmp=qpow(pre[n],mod-2);
for(int j=1;j<=n;j++)inv[j]=1ll*tmp*pre[j-1]%mod*suf[j+1]%mod;
dfs1(i,1,-1),mul(mod-i);
}
for(int i=0;i<=n;i++)div(i);
for(int i=0;i<=n;i++)
{
s[i]=1;
for(int j=0;j<=n;j++)
{
if(i==j)continue;
s[i]=(i-j+mod)%mod*1ll*s[i]%mod;
}
s[i]=qpow(s[i],mod-2);
}
for(int i=1;i<=n;i++)calc(i);
for(int i=1;i<=n;i++)write(dp[i]),putchar(' ');
}