Luogu P5205 【模板】多项式开根
Description
给定一个 \(n-1\) 次多项式 \(A(x)\),求一个在\(\bmod x^n\)意义下的多项式 \(B(x)\),使得 \(B^2(x) \equiv A(x)(\bmod x^n)\)
多项式的系数在 \(\bmod\ 998244353\) 的意义下进行运算。
\(n \leq 10^5,a_i \in [0,998244352] \cap \mathbb{Z}\)
Solution
其实推导过程和多项式求逆类似
考虑倍增
假设我们已经求出了一个多项式 \(G(x)\) 使得 \(G^2(x) \equiv A(x) \ (\bmod\ x^{\lceil \frac{n}{2}\rceil})\) ,而 \(B(x)\) 本来就有 \(B^2(x) \equiv A(x) \ (\bmod\ x^{\lceil \frac{n}{2}\rceil})\) ,那么
平方差公式展开
在这里我们需要说一下究竟取哪个,又有什么区别的问题
假设题目要求的最终的答案为 \(F(x)\)
因为是在模大质数意义下进行的运算,所以要么有 \(B(x)\equiv G(x)(\bmod x^{\lceil \frac{n}{2}\rceil})\) ,要么有 \(B(x)+G(x)\equiv 0(\bmod x^{\lceil \frac{n}{2}\rceil})\) ,至于为什么只需要关注一下 \(0\) 次项的系数就可以了
若我们在倍增的过程中全部选择 \(B(x)\equiv G(x)(\bmod x^{\lceil \frac{n}{2}\rceil})\) 或选择了偶数次 \(B(x)+G(x)\equiv 0(\bmod x^{\lceil \frac{n}{2}\rceil})\) ,那么最后得到的答案就是 \(F(x)\),反之若我们选择了奇数次 \(B(x)+G(x)\equiv 0(\bmod x^{\lceil \frac{n}{2}\rceil})\) ,那么最后得到的答案就是 \(-F(x)\) ,原因在下面的推导中不难看出。所以 \(\sqrt{A(x)}\) 有两解,为 \(\pm F(x)\)
我们选择前者,即
移项后平方展开得到
即
移项得
然后除过去,得
多项式求逆+\(\text{NTT}\)即可
#include<cstdio>
#include<iostream>
using namespace std;
const int N=1e5+10;
const int mod=998244353;
const int g=3;
const int invg=332748118;
int n,a[N<<2],b[N<<2],c[N<<2],d[N<<2],f[N<<2],h[N<<2],k;
inline void Add(int &x,int y){x+=y;x-=x>=mod? mod:0;}
inline int MOD(int x){x-=x>=mod? mod:0;return x;}
inline int fas(int x,int p){int res=1;while(p){if(p&1)res=1ll*res*x%mod;p>>=1;x=1ll*x*x%mod;}return res;}
inline void NTT(int *a,int f){
for(register int i=0,j=0;i<k;i++){
if(i>j)swap(a[i],a[j]);
for(register int l=k>>1;(j^=l)<l;l>>=1);}
for(register int i=1;i<k;i<<=1){
int w=fas(~f? g:invg,(mod-1)/(i<<1));
for(register int j=0;j<k;j+=(i<<1)){
int e=1;
for(register int p=0;p<i;p++,e=1ll*e*w%mod){
int x=a[j+p],y=1ll*a[j+p+i]*e%mod;
a[j+p]=MOD(x+y);a[j+p+i]=MOD(x-y+mod);
}
}
}
}
inline void PINV(int *a,int *b,int deg){
if(deg==1){b[0]=fas(a[0],mod-2);return;}
int M=(deg+1)>>1;PINV(a,b,M);
k=1;while(k<=deg+deg-2)k<<=1;int INV=fas(k,mod-2);
for(register int i=0;i<deg;i++)h[i]=a[i];
for(register int i=deg;i<k;i++)h[i]=0;
NTT(h,1);NTT(b,1);
for(register int i=0;i<k;i++)
b[i]=(2ll-1ll*h[i]*b[i]%mod+mod)*b[i]%mod;
NTT(b,-1);
for(register int i=0;i<deg;i++)b[i]=1ll*b[i]*INV%mod;
for(register int i=deg;i<k;i++)b[i]=0;
}
inline void Sqrt(int *a,int *b,int deg){
if(deg==1){b[0]=1;return;}
int M=(deg+1)>>1;Sqrt(a,b,M);
k=1;while(k<=deg+deg-2)k<<=1;int INV=fas(k,mod-2);
for(register int i=0;i<deg;i++)c[i]=b[i];
for(register int i=deg;i<k;i++)c[i]=0;
NTT(c,1);
for(register int i=0;i<k;i++)c[i]=1ll*c[i]*c[i]%mod;
NTT(c,-1);
for(register int i=0;i<deg;i++)c[i]=1ll*c[i]*INV%mod;
for(register int i=deg;i<k;i++)c[i]=0;
for(register int i=0;i<deg;i++)Add(c[i],a[i]);
for(register int i=0;i<deg;i++)d[i]=MOD(b[i]+b[i]);
for(register int i=deg;i<k;i++)d[i]=0;
for(register int i=0;i<k;i++)f[i]=0;
PINV(d,f,deg);
k=1;while(k<=deg+deg-2)k<<=1;
NTT(f,1);NTT(c,1);
for(register int i=0;i<k;i++)b[i]=1ll*f[i]*c[i]%mod;
NTT(b,-1);
for(register int i=0;i<deg;i++)b[i]=1ll*b[i]*INV%mod;
for(register int i=deg;i<k;i++)b[i]=0;
}
int main(){
scanf("%d",&n);n--;
for(register int i=0;i<=n;i++)scanf("%d",&a[i]);
Sqrt(a,b,n+1);
for(register int i=0;i<=n;i++)printf("%d%c",b[i],i==n? '\n':' ');
return 0;
}