【XSY2978】Product of Roots(多项式)
题面&题意
已知 \(f(x),g(x),h(x)\) 能表示成:
现在给你 \(f(x)\) 和 \(g(x)\),要求 \(h(x)\)。
\(n,m\leq 10^5,k\leq \min(10^5,nm+1)\)。
题解
考虑取对数然后泰勒展开,设 \(l_i(x)=\ln(a_ix+1)\):
引理:\(l_i^{(n)}(x)=\big[\ln(a_ix+1)\big]^{(n)}=(-1)^{n-1}(n-1)!a_i^n(a_ix+1)^{-n}\)(\(n>0\))。
证明:
考虑归纳证明。
首先当 \(n=1\) 时,\(\big[\ln(a_ix+1)\big]'=\ln'(a_ix+1)\cdot\big(a_ix+1\big)'=\dfrac{1}{a_ix+1}\cdot a=a_i(a_ix+1)^{-1}\)。
考虑由 \(l_i^{(n)}(x)\) 往 \(l_i^{(n+1)}(x)\) 推导:
设 \(s(x)=(-1)^{n-1}(n-1)!a_i^nx^{-n}\),\(t(x)=a_ix+1\),那么 \(s'(x)=(-1)^nn!a_i^nx^{-(n+1)}\),\(t'(x)=a_i\)。
则:
\[\begin{aligned} l_i^{(n+1)}(x)&=\left[l_i^{(n)}(x)\right]'\\ &=\bigg[s\big(t(x)\big)\bigg]'=s'\big(t(x)\big)t'(x)\\ &=(-1)^nn!a_i^n(a_ix+1)^{-(n+1)}\cdot a_i\\ &=(-1)^nn!a_i^{n+1}(a_ix+1)^{-(n+1)} \end{aligned} \]证毕。
代入得:
那么易得:
\(g(x)\) 同理。
同样地,我们对 \(h(x)\) 也取对数:
注意到 \(\ln(a_ib_jx+1)\) 和 \(\ln(a_ix+1)\) 形式上是相似的,所以我们也同理地得到:
考虑往 \(f(x)\) 和 \(g(x)\) 的方向凑:
那么:
那么我们可以先通过 \(\ln\big(f(x)\big)\) 和 \(\ln\big(g(x)\big)\) 求出 \(\ln\big(h(x)\big)\),然后再求 \(\exp\) 就可以求出 \(h(x)\) 了。
代码如下:
#include<bits/stdc++.h>
#define LN 19
#define N 100010
using namespace std;
namespace modular
{
const int mod=998244353;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
}using namespace modular;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
inline int poww(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
int n,m,k;
int inv[N<<2];
int rev[N<<2],w[LN][N<<2][2];
int f[N<<2],g[N<<2],h[N<<2];
int lnf[N<<2],lng[N<<2],lnh[N<<2];
void init(int limit)
{
for(int i=0;i<limit;i++) inv[i]=poww(i,mod-2);
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
int len=mid<<1;
int gn=poww(3,(mod-1)/len);
int ign=poww(gn,mod-2);
int g=1,ig=1;
for(int j=0;j<mid;g=mul(g,gn),ig=mul(ig,ign),j++)
w[bit][j][0]=g,w[bit][j][1]=ig;
}
}
void NTT(int *a,int limit,int opt)
{
opt=(opt<0);
for(int i=0;i<limit;i++) rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
for(int i=0;i<limit;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
for(int i=0,len=mid<<1;i<limit;i+=len)
{
for(int j=0;j<mid;j++)
{
int x=a[i+j],y=mul(w[bit][j][opt],a[i+mid+j]);
a[i+j]=add(x,y),a[i+mid+j]=dec(x,y);
}
}
}
if(opt)
{
int tmp=poww(limit,mod-2);
for(int i=0;i<limit;i++)
a[i]=mul(a[i],tmp);
}
}
void getdao(int *f,int *g,int n)
{
for(int i=1;i<n;i++)
g[i-1]=mul(i,f[i]);
g[n-1]=0;
}
void getint(int *f,int *g,int n)
{
for(int i=n-1;i>=1;i--)
g[i]=mul(inv[i],f[i-1]);
g[0]=0;
}
void getinv(int *f,int *g,int n)
{
static int ff[N<<2];
g[0]=poww(f[0],mod-2);
int now=2;
for(;now<(n<<1);now<<=1)
{
int limit=now<<1;
for(int i=0;i<now;i++) ff[i]=f[i];
NTT(ff,limit,1),NTT(g,limit,1);
for(int i=0;i<limit;i++)
g[i]=mul(dec(2,mul(ff[i],g[i])),g[i]);
NTT(g,limit,-1);
for(int i=now;i<limit;i++) g[i]=0;
}
for(int i=n;i<now;i++) g[i]=0;
for(int i=0;i<now;i++) ff[i]=0;
}
void getln(int *f,int *g,int n)
{
static int daof[N<<2],invf[N<<2],daog[N<<2];
getdao(f,daof,n);
getinv(f,invf,n);
int limit=1;
while(limit<(n<<1)) limit<<=1;
NTT(daof,limit,1),NTT(invf,limit,1);
for(int i=0;i<limit;i++) daog[i]=mul(daof[i],invf[i]);
NTT(daog,limit,-1);
getint(daog,g,n);
for(int i=0;i<limit;i++) daof[i]=invf[i]=daog[i]=0;
}
void getexp(int *f,int *g,int n)
{
static int lng[N<<2],ff[N<<2];
g[0]=1;
int now=2;
for(;now<(n<<1);now<<=1)
{
int limit=now<<1;
getln(g,lng,now);
for(int i=0;i<now;i++) ff[i]=dec(f[i],lng[i]);
ff[0]=add(ff[0],1);
NTT(g,limit,1),NTT(ff,limit,1);
for(int i=0;i<limit;i++) g[i]=mul(g[i],ff[i]);
NTT(g,limit,-1);
for(int i=now;i<limit;i++) g[i]=0;
}
for(int i=n;i<now;i++) g[i]=0;
for(int i=0;i<now;i++) lng[i]=ff[i]=0;
}
int main()
{
n=read()+1,m=read()+1,k=read();
int limit=1;
while(limit<(k<<1)) limit<<=1;
init(limit);
for(int i=0;i<n;i++) f[i]=read();
for(int i=0;i<m;i++) g[i]=read();
getln(f,lnf,k),getln(g,lng,k);
for(int i=0;i<k;i++)
{
lnh[i]=mul(i,mul(lnf[i],lng[i]));
if(!(i&1)) lnh[i]=dec(0,lnh[i]);
}
getexp(lnh,h,k);
for(int i=0;i<k;i++)
printf("%d ",h[i]);
return 0;
}
/*
2 2 5
1 2 1
1 2 1
*/