多项式全家桶

题解:

准备这一段时间主要写

数学+数据结构+多项式

下一段时间写

计算几何+字符串

图论和dp感觉简单点的大家都会 难的大家也差不多

参考链接:https://www.cnblogs.com/zhoushuyu/p/8763215.html

ntt模板就之前的没啥变化

fft把处理w[i]那里改了一下 据说可能会有精度问题

还是写普通fft吧。。虽然mx的减少了一倍dft的次数

1.多项式求逆:

$$ A(x)B(x) \equiv 1 \ ( \ mod \ x^n)  $$

这个我们采用递归求解

先求解

$$ A(x)B'(x) \equiv 1 \ ( \ mod \ x^{\frac{n}{2}})$$

减一减可以得到

$$ B(x)-B'(x) \equiv 0 \  ( \ mod \ x^{\frac{n}{2}}) $$

两边平方

$$ B^2(x)-2B'(x)B(x)+{B'}^{2}(x) \equiv 0 \ ( \ mod \ x^n) $$

这一步后面模数平方的原因是 $x^k=x^a*x^b$

当$k<=n$时 $a,b$一定有一个$<=\frac{n}{2}$ 为0

然后乘上$A(x)$ 得到

$$B(x) \equiv 2B'(x)-A(x)*B'^{2}(x) \ ( \ mod \ x^n) $$

然后这个东西可以fft/ntt求

$$n+n/2+n/4+n/8+...=2n$$

所以总复杂度$nlogn$ 常数是6倍

另外多项式存在逆元的条件是常数项有逆元

这个东西怎么手动验算呢

我们直接将$A(x)B(x)$乘起来就好,然后与1相减判断一下是不是$x^n$的$g(x)$倍

写这个代码的时候注意空间开4倍

因为你当前需要的最高项m是$2*len$,然后你n的上限是$2m$,所以就是$4*len$了

$now$的下一层是$(now+1)/2$而不是$now/2$

原因是如果取$now/2$的话,上面证明平方后是$now-1$而非$now$

至于中间数组可以开在递归内部避免清空了

另外把n,m写错了查了我一个小时。。

2.多项式开方

$$ A(x) \equiv B(x)*B(x) \ (\ mod \ x^n) $$

依旧采用递归求解,按照上面可以得到

$$ (B(x)+B'(x))*(B(x)-B'(x)) \equiv 0 \ (\ mod \ x^{\frac{n}{2}}) $$

这样的话会有两个解,我们取

$$B(x)-B'(x) \equiv 0 \ ( \ mod \ x^{\frac{n}{2}}) $$

两边平方

$$ B^2(x)-2B'(x)B(x)+{B'}^{2}(x) \equiv 0 \ ( \  mod \ x^n)$$

因为$ B^2(x) \equiv A(x) \ ( \ mod \ x^n) $

$$ B(x) \equiv \frac{A(x)+{B'}^{2}(x)}{2B'(x)} \ ( \ mod \ x^n)$$

这个东西手动验算方法和上面的一样

不过递归到$len==1$的时候,我们需要处理$i^2=k$这么一样事情

如果不是模意义下倒是好求,模意义下需要用二次剩余

百度了一下好麻烦啊不太想学。。(大家好像都没学)

3.多项式求ln

$$ ln(A(x)) \equiv B(x) $$

两边求导

$$ \frac{A'(x)}{A(x)} \equiv B'(x) $$

然后积分

4.牛顿迭代

现在已知一个函数G(x),求一个多项式F(x),满足

$$ G(F(x)) \equiv 0 \ ( \ mod \ x^n) $$

其实应该先学这个东西的 上面的东西都可以用这个做

$$ G(F'(x)) \equiv 0 \ ( \ mod \ x^{\frac{n}{2}}) $$

然后看如何扩展到$x^n$

使用泰勒展开

$$ G(F(x))=G(F(x))+\frac{G'(F'(x))}{1!}(F(x)-F'(x))+.... $$

我们会发现 $(F(x)-F'(x))^k$ 在$k>=2$的时候 $mod x^n =0$

又因为$G(F(z)) \equiv 0 \ ( \ mod x^n) $,得到

$$ F(x) \equiv F'(x)-\frac{G(F'(x))}{G'(F'(x))} \ ( \ mod \ x^n ) $$

可以试一下用牛顿迭代去做前面的求逆和开方 效果是一样的

因为不想打了。。。放个链接

http://blog.miskcoo.com/2015/06/polynomial-with-newton-method

多项式求逆yyb的博客上有,但证明不是很全 $\frac{1}{A(x)}$变成$B’(x)$需要一些说明

因为本来应该是变成$B(x)$的

5.多项式求exp

$$ e^{(A(x))} \equiv B(x)  \ ( \ mod \ x^n)$$

$$ A(x) \equiv ln(B(x)) \ ( \ mod \ x^n) $$

然后构造$ G(B(x))=B(x)-A(x) $

利用牛顿迭代,得到

$$ B(x)=B'(x)-\frac{ln(B'(x))-A(x)}{\frac{1}{B'(x)}} $$

化简一下得到

$$ B(x)=B'(x)(1-ln(B'(x))+A(x)) $$

其中要保证 初始A(0)=0,而求出的B(0)=1

至于上面两个如果要手动验算的话,我们可以都选择验算${e}^{f(x)}$

$$ {e}^{f(x)}=\prod_{i=1}^{n} {{e}^{ai*x^i}}$$

由泰勒展开

$$e^x=1+\frac{x}{1!}+\frac{x^2}{2!}+\frac{x^3}{3!}+...$$

$${e}^{f(x)}=\prod_{i=0}^{n} {\sum_{k=0}^{INF} \frac{{(ai*x^i)}^{k}}{k!} }$$

然后你会发现如果初始$a[0] \not = 0$的话

不妨设$a[0]=1$当i=0时,有$\sum_{k=0}^{INF} {\frac{1}{k!}}$ 就么一个常数项 就比较gg了

另外为啥$b[0]=1$呢 我们注意到当$i \not = 0$的时候

只有$k=0$才有常数项 而$k=0$时常数项都是1,所以最后也是1 

现在举个例子手算一下${e}^{x^2+2x}$

可以得到答案

以上都是我写代码之前想的

写完代码之后我发现,你只要把exp和ln都写了

先exp再ln回去不就行了吗

先求逆再乘一下不就行了么

当然万一错了。。可以用小数据按照上面的式子算

 

#include <bits/stdc++.h>
using namespace std;
#define rint register int
#define IL inline
#define rep(i,h,t) for(int i=h;i<=t;i++)
#define dep(i,t,h) for(int i=t;i>=h;i--)
#define ll long long
#define me(x) memset(x,0,sizeof(x))
#define mep(x,y) memcpy(x,y,sizeof(y))
#define mid (t<=0?(h+t-1)/2:(h+t)/2)
namespace IO{
    char ss[1<<24],*A=ss,*B=ss;
    IL char gc()
    {
        return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++;
    }
    template<class T> void read(T &x)
    {
        rint f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48);
        while (c=gc(),c>47&&c<58) x=(x<<3)+(x<<1)+(c^48); x*=f; 
    }
    char sr[1<<24],z[20]; ll Z,C1=-1;
    template<class T>void wer(T x)
    {
        if (x<0) sr[++C1]='-',x=-x;
        while (z[++Z]=x%10+48,x/=10);
        while (sr[++C1]=z[Z],--Z);
    }
    IL void wer1()
    {
        sr[++C1]=' ';
    }
    IL void wer2()
    {
        sr[++C1]='\n';
    }
    template<class T>IL void maxa(T &x,T y) {if (x<y) x=y;}
    template<class T>IL void mina(T &x,T y) {if (x>y) x=y;} 
    template<class T>IL T MAX(T x,T y){return x>y?x:y;}
    template<class T>IL T MIN(T x,T y){return x<y?x:y;}
};
using namespace IO;
const double ee=1.00000000000000;
const double pi=acos(-1.0);
const int N=4e5+10;
int a[N],b[N],w[N],r[N],n1[N],n2[N],n,m,l,inv[N];
const int mo=998244353;
const int G=3;
IL int fsp(int x,int y)
{
    ll now=1;
    while (y)
    {
        if (y&1) now=now*x%mo;
        x=1ll*x*x%mo;
        y>>=1;
    }
    return now;
}
IL void ntt_init()
{
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
}
IL void clear()
{
    for (int i=0;i<=n;i++) a[i]=b[i]=0;
}
/*
struct cp{
    double a,b;
    cp operator +(const &o) const
    {
        return (cp){a+o.a,b+o.b};
    }
    cp operator -(const &o) const
    {
        return (cp){a-o.a,b-o.b};
    }
    cp operator *(const &o) const
    {
        return (cp){a*a.o-b*b.o,o.a*b+o.b*a);
    }
};
IL void ftt_init()
{
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) w[i]=cp(cos(pi*i/n),sin(pi*i/n));
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
}
void ftt(int *a,int o)
{
    for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
    for (int i=1;i<n;i<<=1)
        for (int j=0;j<n;j+=(i*2))
          for (int k=0;k<i,k++)
          {
              cp W=w[n/i*k]; W.b*=o;
              cp x=a[j+k],y=w*a[i+j+k];
              a[j+k]=w+y; a[i+j+k]=x-y;
          }
    if (o==-1) for (int i=0;i<n;i++) a[i].a/=ee*n;
}*/
void ntt(int *a,int o)
{
    for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
    for (int i=1;i<n;i<<=1)
    {
        int wn=fsp(G,(mo-1)/(i*2)); w[0]=1;
        rep(j,1,i-1) w[j]=(1ll*w[j-1]*wn)%mo;
        for (int j=0;j<n;j+=(i*2))
          for (int k=0;k<i;k++)
          {
              int x=a[j+k],y=1ll*a[i+j+k]*w[k]%mo;
              a[j+k]=(x+y)%mo; a[i+j+k]=(x-y)%mo;
          }
    }
    if (o==-1)
    {
      reverse(&a[1],&a[n]);
      for (int i=0,inv=fsp(n,mo-2);i<n;i++)
        a[i]=1ll*a[i]*inv%mo;
    }
}
IL void getcj(int *A,int *B,int len)
{
    m=len*2; ntt_init();
    for (int i=0;i<len;i++) a[i]=A[i],b[i]=B[i];
    ntt(a,1); ntt(b,1);
    for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
    ntt(a,-1);
    for (int i=0;i<len;i++) B[i]=a[i];
    clear();
}
IL void getinv(int *A,int *B,int len)
{
    if (len==1) { B[0]=fsp(A[0],mo-2); return; }
    getinv(A,B,(len+1)>>1);
    m=len*2; ntt_init();
    for (int i=0;i<len;i++) a[i]=A[i],b[i]=B[i];
    ntt(a,1); ntt(b,1);
    for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo*b[i]%mo;
    ntt(a,-1);
    for (int i=0;i<len;i++) B[i]=(2*B[i]%mo-a[i])%mo; 
    clear();
}
IL void getsqrt(int *A,int *B,int len)
{
    int inv2=fsp(2,mo-2);
    if (len==1) {B[0]=sqrt(A[0]); return;}
    getsqrt(A,B,(len+1)>>1);
    int C[N]={};
    getinv(B,C,len);
    getcj(A,C,len);
    for (int i=0;i<len;i++) B[i]=1ll*(C[i]+B[i])%mo*inv2%mo;
}
IL void getDao(int *a,int *b,int len)
{
    for (int i=1;i<len;i++) b[i-1]=1ll*i*a[i]%mo;
    b[len-1]=0;
}
IL void getjf(int *a,int *b,int len)
{
    for (int i=0;i<len;i++) b[i+1]=1ll*a[i]*inv[i+1]%mo;
    b[0]=0;
}
IL void getln(int *A,int *B,int len)
{
    int C[N],D[N];
    getDao(A,C,len);
    getinv(A,D,len);
    getcj(C,D,len);
    getjf(D,B,len);
}
IL void getexp(int *A,int *B,int len)
{
    if (len==1) {B[0]=1; return;}
    getexp(A,B,(len+1)>>1);
    int C[N];
    getln(B,C,len);
    for(int i=0;i<len;i++) C[i]=(-C[i]+A[i])%mo;
    C[0]=(C[0]+1)%mo;
    getcj(C,B,len);
}
int main()
{
    freopen("1.in","r",stdin);
    freopen("1.out","w",stdout);
    inv[1]=1;
    rep(i,2,1e5+20) inv[i]=(1ll*inv[mo%i]*(mo-(mo/i)))%mo; 
    int len;
    read(len);
    rep(i,0,len-1) read(n1[i]);
/*    getinv(n1,n2,len); */
/*    getsqrt(n1,n2,len);
    rep(i,0,len-1) wer((n2[i]+mo)%mo),wer1();
    getcj(n2,n2,len);
    wer2(); 
    rep(i,0,len-1) wer((n1[i]-n2[i])%mo),wer1();*/
/*    getln(n1,n2,len);
    rep(i,0,len-1) wer((n2[i]+mo)%mo),wer1(); */
/*    getexp(n1,n2,len);
    rep(i,0,len-1) wer((n2[i]+mo)%mo),wer1(); */
    fwrite(sr,1,C1+1,stdout);
    return 0;
}

 

 

 

6.多项式快速幂

可以使用两边取$ln$

但这玩意直接快速幂$nlog^2$跑的要比exp快啊。。

7.某种分治fft

问题:计算$\prod_{i=1}^{n} {1+a[i]x}$

要是我们从左到右过去这么算复杂度就是$n^2logn$的了

我们要去平均系数

所以就使用分治了

每一层的复杂度都是$nlogn$

$$T(n)=nlogn+2T(n/2)$$

复杂度当然就是$nlog^2{n}$的了

扩展:

这个非常的套路

对于$x \in [1,m]$ 计算$\sum_{i=1}^{n} {a[i]^x}$

令$f(x)=\prod _{i=1}^{n} {(1+a[i]x)}$ 首先先分治fft算出这个东西

两边求导

$$ln(f(x))=\sum_{i=1}^{n} {ln(1+a[i]x)} $$

对于后面那个泰勒展开

$$ln(f(x))=\sum_{i=1}^{n} {\sum_{j=1}^{INF} {\frac{{(-1)}^{j-1}{(a[i]x)}^{j}}{j}}}$$

然后交换一下枚举顺序

得到

$$ln(f(x))=\sum_{j=1}^{INF} {\frac{{(-1)}^{j-1}}{j} \ \sum_{i=1}^{n}} {{(a[i]x)}^{j}} $$

也就是得到了我们要求的东西

8.多项式除法&&多项式取模

 

已知n次多项式$A(x)$和m次多项式$B(x)$

求商$C(x)$,余式$D(x)$

$$A(x)=B(x)*C(x)+D(x)$$

$$A(\frac{1}{x})=B(\frac{1}{x})*C(\frac{1}{x})+D(\frac{1}{x})$$

$$x^n*A(\frac{1}{x})=x^m*B(\frac{1}{x})*{x}^{n-m}*C(\frac{1}{x})+x^{n-m+1}*x^{m-1}*D(\frac{1}{x}) $$

然后注意一下$x^n*A(\frac{1}{x})$ 假设A是n次多项式,那么等价于将它系数翻转

我们把式子对$x^{n-m+1}$取模 得到

$$A'(x)=B'(x)C'(x)$$

$$C'(x)=\frac{A'(x)}{B'(x)} $$

那么多项式求逆,再翻转一下系数就可以求出C

然后$D(x)=A(x)-C(x)*B(x)$

#注意:

1.前面的运算是在$ mod \ ( \  x^{n-m+1} \ ) \ $进行的,所以长度为$n-m+2$

2.算出k之后要清空它后面的项,不然就gg了

写的时候并没有注意常数。。可能有点大

另外下次应该得改一改getcj里传入的参数。。

不然两个串不同引起了浪费。。

 

#include <bits/stdc++.h>
using namespace std;
#define rint register int
#define IL inline
#define rep(i,h,t) for (int i=h;i<=t;i++)
#define dep(i,t,h) for (int i=t;i>=h;i--)
#define me(x) memset(x,0,sizeof(x))
#define ll long long
namespace IO{
    char ss[1<<24],*A1=ss,*B1=ss;
    IL char gc()
    {
        return A1==B1&&(B1=(A1=ss)+fread(ss,1,1<<24,stdin),A1==B1)?EOF:*A1++;
    }
    template<class T>void read(T &x)
    {
        rint f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48);
        while (c=gc(),c>47&&c<58) x=(x<<3)+(x<<1)+(c^48); x*=f;
    }
    char sr[1<<24],z[20]; int Z,C1=-1;
    template<class T>void wer(T x)
    {
        if (x<0) sr[++C1]='-',x=-x;
        while (z[++Z]=x%10+48,x/=10);
        while (sr[++C1]=z[Z],--Z);
    }
    IL void wer1()
    {
        sr[++C1]=' ';
    }
    IL void wer2()
    {
        sr[++C1]='\n';
    }
};
using namespace IO;
const int N=12.1e5;
int n,m,l,k1[N],k2[N],k3[N],a[N],b[N],r[N],w[N];
const int mo=998244353;
const int G=3;
IL int fsp(int x,int y)
{
    ll now=1;
    while (y)
    {
        if (y&1) now=now*x%mo;
        x=1ll*x*x%mo;
        y>>=1;
    }
    return now;
}
IL void ntt_init()
{
    l=0;
    for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
}
IL void clear()
{
    rep(i,0,n) a[i]=b[i]=0;
}
IL void ntt(int *a,int o)
{
    for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
    for (int i=1;i<n;i*=2)
    {
      int wn=fsp(G,(mo-1)/(i*2)); w[0]=1;
      for (int j=1;j<i;j++) w[j]=1ll*w[j-1]*wn%mo;
      for (int j=0;j<n;j+=(i*2))
        for (int k=0;k<i;k++)
        {
            int x=a[j+k],y=1ll*a[i+j+k]*w[k]%mo;
            a[j+k]=(x+y)%mo; a[i+j+k]=(x-y)%mo;
        }
    }
    if (o==-1)
    {
        reverse(&a[1],&a[n]);
        int inv=fsp(n,mo-2);
        for (int i=0;i<n;i++) a[i]=1ll*a[i]*inv%mo;
    }
}
IL void getcj(int *A,int *B,int len)
{
    m=len*2; ntt_init();
    for (int i=0;i<len;i++) a[i]=A[i],b[i]=B[i];
    ntt(a,1); ntt(b,1);
    for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
    ntt(a,-1);
    for (int i=0;i<m;i++) B[i]=a[i];
    clear();
}
int C[N];
IL void getinv(int *A,int *B,int len)
{
    if (len==1) { B[0]=fsp(A[0],mo-2); return; }
    getinv(A,B,(len+1)/2);
    rep(i,0,len) C[i]=B[i];
    getcj(A,C,len); getcj(B,C,len);
    rep(i,0,len) B[i]=(2ll*B[i]-C[i])%mo;
}
int main()
{
    freopen("1.in","r",stdin);
    freopen("1.out","w",stdout);
    int n1,m1;
    read(n1); read(m1);
    rep(i,0,n1) read(k1[i]);
    rep(i,0,m1) read(k2[i]);
    reverse(k1,k1+n1+1);
    reverse(k2,k2+m1+1);
    int k=n1-m1+2;
    getinv(k2,k3,k);
    getcj(k1,k3,k);
    reverse(k3,k3+n1-m1+1);
    rep(i,0,n1-m1) wer((k3[i]+mo)%mo),wer1();
    rep(i,n1-m1+1,4*m1) k3[i]=0;
    wer2();
    reverse(k1,k1+n1+1);
    reverse(k2,k2+m1+1);
    getcj(k3,k2,m1);
    rep(i,0,m1-1) wer(((k1[i]-k2[i])%mo+mo)%mo),wer1();
    fwrite(sr,1,C1+1,stdout);
    return 0;
}

 

 

 

9.任意模数多项式乘法(MTT)

另一种方法是3模数ntt再crt合并 还是这种比较简单

为啥不能直接fft呢,虽然里面是double可最后你还得取模的呀

计算$F(x)*G(x)$

设$M=\sqrt{n},F(x)=M*A(x)+B(x),G(x)=M*C(x)+D(x)$

然后乘一下,得到

$F(x)*G(x)=M^2*C(x)*A(x)+M*(B(x)C(x)+A(x)D(x))+B(x)*D(x)$

注意取$\sqrt{n}$的意思是所有数里面最大的数 所以暴力一点直接30000就可以了

 

#include <bits/stdc++.h>
using namespace std;
#define rint register int
#define IL inline
#define rep(i,h,t) for(int i=h;i<=t;i++)
#define dep(i,t,h) for(int i=t;i>=h;i--)
#define ll long long
#define me(x) memset(x,0,sizeof(x))
#define mep(x,y) memcpy(x,y,sizeof(y))
#define mid (t<=0?(h+t-1)/2:(h+t)/2)
namespace IO{
    char ss[1<<24],*A=ss,*B=ss;
    IL char gc()
    {
        return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++;
    }
    template<class T> void read(T &x)
    {
        rint f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48);
        while (c=gc(),c>47&&c<58) x=(x<<3)+(x<<1)+(c^48); x*=f; 
    }
    char sr[1<<24],z[20]; ll Z,C1=-1;
    template<class T>void wer(T x)
    {
        if (x<0) sr[++C1]='-',x=-x;
        while (z[++Z]=x%10+48,x/=10);
        while (sr[++C1]=z[Z],--Z);
    }
    IL void wer1()
    {
        sr[++C1]=' ';
    }
    IL void wer2()
    {
        sr[++C1]='\n';
    }
    template<class T>IL void maxa(T &x,T y) {if (x<y) x=y;}
    template<class T>IL void mina(T &x,T y) {if (x>y) x=y;} 
    template<class T>IL T MAX(T x,T y){return x>y?x:y;}
    template<class T>IL T MIN(T x,T y){return x<y?x:y;}
};
using namespace IO;
long double ee=1.00000000000000;
long double pi=std::acos(-1.0);
const int N=4e5+10;
int r[N],n,m,l,inv[N],mo,mo2;
int x1[N],x2[N];
struct cp{
    long double a,b;
    cp operator +(const cp &o) const
    {
        return (cp){a+o.a,b+o.b};
    }
    cp operator -(const cp &o) const
    {
        return (cp){a-o.a,b-o.b};
    }
    cp operator *(const cp &o) const
    {
        return (cp){a*o.a-b*o.b,o.a*b+o.b*a};
    }
}a[N],b[N],w[N];
IL int fsp(int x,int y)
{
    ll now=1;
    while (y)
    {
        if (y&1) now=now*x%mo;
        x=1ll*x*x%mo;
        y>>=1;
    }
    return now;
}
IL void clear()
{
    for (int i=0;i<=n;i++) a[i].a=a[i].b=b[i].a=b[i].b=0;
}
IL void fft_init()
{
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) w[i]=(cp){std::cos(pi*i/n),std::sin(pi*i/n)};
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
}
void fft(cp *a,int o)
{
    for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
    for (int i=1;i<n;i<<=1)
        for (int j=0;j<n;j+=(i*2))
          for (int k=0;k<i;k++)
          {
              cp W=w[n/i*k]; W.b*=o;
              cp x=a[j+k],y=W*a[i+j+k];
              a[j+k]=x+y; a[i+j+k]=x-y;
          }
    if (o==-1) for (int i=0;i<n;i++) a[i].a/=ee*n;
}
IL void getcj(cp *A,cp *B,int len)
{
    m=len;
    for (int i=0;i<len;i++) a[i]=A[i],b[i]=B[i];
    fft_init();
    fft(a,1); fft(b,1);
    for (int i=0;i<n;i++) a[i]=a[i]*b[i];
    fft(a,-1);
    for (int i=0;i<len;i++) B[i].a=(ll)(a[i].a+0.5)%mo;
    clear();
}
IL void getcj2(int *A,int *B,int len)
{
    cp a1[N],a2[N],a3[N],b1[N],b2[N];
    for (int i=0;i<len;i++)
    {
      a3[i].a=a1[i].a=A[i]/mo2,a2[i].a=A[i]%mo2;
      b1[i].a=B[i]/mo2,b2[i].a=B[i]%mo2;
    }
    int now[N]={};
    getcj(b1,a1,len);
    for (int i=0;i<len;i++) now[i]=(now[i]+(ll)(a1[i].a)*mo2%mo*mo2%mo)%mo;
    getcj(a2,b1,len);
    for (int i=0;i<len;i++) now[i]=(now[i]+(ll)(b1[i].a)*mo2%mo)%mo;
    getcj(b2,a3,len);
    for (int i=0;i<len;i++) now[i]=(now[i]+(ll)(a3[i].a)*mo2%mo)%mo;
    getcj(a2,b2,len);
    for (int i=0;i<len;i++) now[i]=(now[i]+(ll)(b2[i].a))%mo;
    for (int i=0;i<len;i++) B[i]=now[i];
}
int main()
{
    freopen("1.in","r",stdin);
    freopen("1.out","w",stdout);
    int n1,m1;
    read(n1); read(m1); read(mo); mo2=30000;
    rep(i,0,n1) read(x1[i]);
    rep(i,0,m1) read(x2[i]);
    getcj2(x1,x2,n1+m1+1);
    rep(i,0,n1+m1) wer(x2[i]),wer1();
    fwrite(sr,1,C1+1,stdout);
    return 0;
}

 注意由于负数/法和取模与正数并不相同,所以我们要保证乘的式子都是非负的

【模板】多项式求逆(加强版) 刚开始就因为这样发现过不了样例

这题比较卡常数。。所以得优化一下

把暴力4次卷积变成

4次dft过来再3次dft回去 常数小了一倍

另外不要在循环里面a[]={} 这个挺花时间的

我这样每次inv用了14次dft

inv里面a*b*b可以再优化一下 b只用算一次 这样是12次dft

利用mx的论文的技巧可以做到7还是8次dft

// luogu-judger-enable-o2
#include <bits/stdc++.h>
using namespace std;
#define rint register int
#define IL inline
#define rep(i,h,t) for(int i=h;i<=t;i++)
#define dep(i,t,h) for(int i=t;i>=h;i--)
#define ll long long
#define me(x) memset(x,0,sizeof(x))
#define mep(x,y) memcpy(x,y,sizeof(y))
#define mid (t<=0?(h+t-1)/2:(h+t)/2)
namespace IO{
    char ss[1<<24],*A=ss,*B=ss;
    IL char gc()
    {
        return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++;
    }
    template<class T> void read(T &x)
    {
        rint f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48);
        while (c=gc(),c>47&&c<58) x=(x<<3)+(x<<1)+(c^48); x*=f; 
    }
    char sr[1<<24],z[20]; ll Z,C1=-1;
    template<class T>void wer(T x)
    {
        if (x<0) sr[++C1]='-',x=-x;
        while (z[++Z]=x%10+48,x/=10);
        while (sr[++C1]=z[Z],--Z);
    }
    IL void wer1()
    {
        sr[++C1]=' ';
    }
    IL void wer2()
    {
        sr[++C1]='\n';
    }
    template<class T>IL void maxa(T &x,T y) {if (x<y) x=y;}
    template<class T>IL void mina(T &x,T y) {if (x>y) x=y;} 
    template<class T>IL T MAX(T x,T y){return x>y?x:y;}
    template<class T>IL T MIN(T x,T y){return x<y?x:y;}
};
using namespace IO;
double ee=1.00000000000000;
double pi=acos(-1.0);
const int mo=1e9+7;
const int N=4e5+10;
int r[N],n,m,l,inv[N],mo2,mo3;
int x1[N],x2[N];
struct cp{
    double a,b;
    IL cp operator +(const cp &o) const
    {
        return (cp){a+o.a,b+o.b};
    }
    IL cp operator -(const cp &o) const
    {
        return (cp){a-o.a,b-o.b};
    }
    IL cp operator *(register const cp &o) const
    {
        return (cp){a*o.a-b*o.b,o.a*b+o.b*a};
    }
}a[N],b[N],w[N];
IL int fsp(int x,int y)
{
    ll now=1;
    while (y)
    {
        if (y&1) now=now*x%mo;
        x=1ll*x*x%mo;
        y>>=1;
    }
    return now;
}
IL void clear()
{
    for (int i=0;i<=n;i++) a[i].a=a[i].b=b[i].a=b[i].b=0;
}
IL void fft_init()
{
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) w[i]=(cp){std::cos(pi*i/n),std::sin(pi*i/n)};
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
}
IL void fft(cp *a,int o)
{
    for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
    for (rint i=1;i<n;i<<=1)
        for (rint j=0;j<n;j+=(i*2))
        {
          register cp *x1=a+j,*x2=a+i+j;
          for (rint k=0;k<i;k++)
          {
              cp W=w[n/i*k]; W.b*=o;
              cp x=x1[k],y=W*x2[k];
              x1[k]=x+y; x2[k]=x-y;
          }
        }
    if (o==-1) for (int i=0;i<n;i++) a[i].a/=ee*n;
}
IL void getcj(cp *A,cp *B,int len)
{
    for (int i=0;i<len;i++) a[i]=A[i],b[i]=B[i];
    m=len*2;
    fft_init();
    fft(a,1); fft(b,1);
    for (int i=0;i<n;i++) a[i]=a[i]*b[i];
    fft(a,-1);
    for (int i=0;i<len;i++) B[i].a=(ll)(a[i].a+0.5)%mo;
    clear();
}
cp a1[N]={},a2[N]={},a3[N]={},b1[N]={},b2[N]={};
IL void getcj2(int *A,int *B,int len)
{
    for (int i=0;i<len;i++)
    {
      a1[i].a=A[i]/mo2,a2[i].a=A[i]%mo2;
      b1[i].a=B[i]/mo2,b2[i].a=B[i]%mo2;
  }
    int now[N]={};
 /*   getcj(b1,a1,len);
    getcj(a2,b1,len);
    getcj(b2,a3,len);
    getcj(a2,b2,len); */
    m=len*2; fft_init();
    fft(b1,1); fft(a1,1); fft(a2,1); fft(b2,1);
    for (int i=0;i<n;i++) a3[i]=a2[i]*b1[i]+b2[i]*a1[i];
    for (int i=0;i<n;i++) b2[i]=a2[i]*b2[i];
    for (int i=0;i<n;i++) a1[i]=a1[i]*b1[i];
    fft(a1,-1); fft(a3,-1); fft(b2,-1);
    for (int i=0;i<len;i++)
    { 
      now[i]=((ll)(a1[i].a+0.5)%mo*mo3+mo2*(((ll)(a3[i].a+0.5))%mo)+(ll)(b2[i].a+0.5))%mo;
      if (now[i]<0) now[i]+=mo;
    }
    for (int i=0;i<n;i++) a1[i].a=a1[i].b=a2[i].a=a2[i].b=b1[i].a=b1[i].b=b2[i].a=b2[i].b=a3[i].a=a3[i].b=0;
    for (int i=0;i<len;i++) B[i]=now[i];
}
void getinv(int *A,int *B,int len)
{
    if (len==1) { B[0]=fsp(A[0],mo-2); return;};
    getinv(A,B,(len+1)/2);
    int C[N]={};
    rep(i,0,len-1) C[i]=A[i];
    getcj2(B,C,len);
    getcj2(B,C,len);
  for (int i=0;i<len;i++) B[i]=((2ll*B[i]-C[i])%mo+mo)%mo;
}
int main()
{
    int n1;
    read(n1); mo2=sqrt(mo); mo3=mo2*mo2;
    rep(i,0,n1-1) read(x1[i]);
    getinv(x1,x2,n1);
    rep(i,0,n1-1) x2[i]=(x2[i]+mo)%mo;
    rep(i,0,n1-1) wer(x2[i]),wer1();
/*    getcj2(x1,x2,n1);
    wer2();
    rep(i,0,n1-1) wer(x2[i]),wer1(); */
    fwrite(sr,1,C1+1,stdout);
    return 0;
}

本来不想改模板的,被两道mtt的题卡常乐了,发现mx的在这里的应用既好写又跑的快

并不太理解原理就背板子了。。

两次都是dft 不是idft

然后大概意思就是把实部作余数,虚部作除数了

// luogu-judger-enable-o2
// luogu-judger-enable-o2
#include <bits/stdc++.h>
using namespace std;
#define rint register int
#define IL inline
#define rep(i,h,t) for(int i=h;i<=t;i++)
#define dep(i,t,h) for(int i=t;i>=h;i--)
#define ll long long
#define me(x) memset(x,0,sizeof(x))
#define mep(x,y) memcpy(x,y,sizeof(y))
#define mid (t<=0?(h+t-1)/2:(h+t)/2)
namespace IO{
    char ss[1<<24],*A=ss,*B=ss;
    IL char gc()
    {
        return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++;
    }
    template<class T> void read(T &x)
    {
        rint f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48);
        while (c=gc(),c>47&&c<58) x=(x<<3)+(x<<1)+(c^48); x*=f; 
    }
    char sr[1<<24],z[20]; ll Z,C1=-1;
    template<class T>void wer(T x)
    {
        if (x<0) sr[++C1]='-',x=-x;
        while (z[++Z]=x%10+48,x/=10);
        while (sr[++C1]=z[Z],--Z);
    }
    IL void wer1()
    {
        sr[++C1]=' ';
    }
    IL void wer2()
    {
        sr[++C1]='\n';
    }
    template<class T>IL void maxa(T &x,T y) {if (x<y) x=y;}
    template<class T>IL void mina(T &x,T y) {if (x>y) x=y;} 
    template<class T>IL T MAX(T x,T y){return x>y?x:y;}
    template<class T>IL T MIN(T x,T y){return x<y?x:y;}
};
using namespace IO;
double ee=1.00000000000000;
double pi=std::acos(-1.0);
const int N=4e5+10;
int r[N],n,m,l,inv[N],mo;
int x1[N],x2[N];
struct cp{
    double a,b;
    cp operator +(const cp &o) const
    {
        return (cp){a+o.a,b+o.b};
    }
    cp operator -(const cp &o) const
    {
        return (cp){a-o.a,b-o.b};
    }
    cp operator *(const cp &o) const
    {
        return (cp){a*o.a-b*o.b,o.a*b+o.b*a};
    }
}a[N],b[N],c[N],d[N],w[N];
IL int fsp(int x,int y)
{
    ll now=1;
    while (y)
    {
        if (y&1) now=now*x%mo;
        x=1ll*x*x%mo;
        y>>=1;
    }
    return now;
}
IL void clear()
{
    for (int i=0;i<=n;i++) a[i].a=a[i].b=b[i].a=b[i].b=0;
}
IL void fft_init()
{
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) w[i]=(cp){std::cos(pi*i/n),std::sin(pi*i/n)};
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
}
void fft(cp *a,int o)
{
    for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
    for (int i=1;i<n;i<<=1)
        for (int j=0;j<n;j+=(i*2))
        {
          cp *x1=a+j,*x2=a+i+j;
          for (int k=0;k<i;k++,x1++,x2++)
          {
              cp W=w[n/i*k]; W.b*=o;
              cp x=*x1,y=W*(*x2);
              *x1=x+y,*x2=x-y;
          }
        }
    if (o==-1) for(int i=0;i<n;i++) a[i].a/=n;
}
cp a1[N]={},a2[N]={},a3[N]={},b1[N]={},b2[N]={};
IL void getcj(int *A,int *B,int len)
{
    for (int i=0;i<len;i++)
    {
       a[i]=(cp){A[i]&32767,A[i]>>15};
       b[i]=(cp){B[i]&32767,B[i]>>15};
    }
    m=len; fft_init();
    fft(a,1); fft(b,1);
    for (int i=0;i<n;i++)
    {
        int j=(n-1)&(n-i);
        c[j]=(cp){0.5*(a[i].a+a[j].a),0.5*(a[i].b-a[j].b)}*b[i];
        d[j]=(cp){0.5*(a[i].b+a[j].b),0.5*(a[j].a-a[i].a)}*b[i];
    }
    fft(c,1); fft(d,1);
    double inv=ee/n;
    rep(i,0,n) c[i].a*=inv,c[i].b*=inv;
    rep(i,0,n) d[i].a*=inv,d[i].b*=inv;
    rep(i,0,len)
    {
        ll a1=c[i].a+0.5,a2=c[i].b+0.5;
        ll a3=d[i].a+0.5,a4=d[i].b+0.5;
        B[i]=(a1+((a2+a3)<<15)+((a4%mo)<<30))%mo;
    }
}
int main()
{
    freopen("1.in","r",stdin);
    freopen("1.out","w",stdout);
    int n1,m1;
    read(n1); read(m1); read(mo);
    rep(i,0,n1) read(x1[i]);
    rep(i,0,m1) read(x2[i]);
    getcj(x1,x2,n1+m1+1);
    rep(i,0,n1+m1) wer(x2[i]),wer1();
    fwrite(sr,1,C1+1,stdout);
    return 0;
}

 

#include <bits/stdc++.h>
using namespace std;
#define rint register int
#define IL inline
#define rep(i,h,t) for(int i=h;i<=t;i++)
#define dep(i,t,h) for(int i=t;i>=h;i--)
#define ll long long
#define me(x) memset(x,0,sizeof(x))
#define mep(x,y) memcpy(x,y,sizeof(y))
#define mid (t<=0?(h+t-1)/2:(h+t)/2)
namespace IO{
    char ss[1<<24],*A=ss,*B=ss;
    IL char gc()
    {
        return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++;
    }
    template<class T> void read(T &x)
    {
        rint f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48);
        while (c=gc(),c>47&&c<58) x=(x<<3)+(x<<1)+(c^48); x*=f; 
    }
    char sr[1<<24],z[20]; ll Z,C1=-1;
    template<class T>void wer(T x)
    {
        if (x<0) sr[++C1]='-',x=-x;
        while (z[++Z]=x%10+48,x/=10);
        while (sr[++C1]=z[Z],--Z);
    }
    IL void wer1()
    {
        sr[++C1]=' ';
    }
    IL void wer2()
    {
        sr[++C1]='\n';
    }
    template<class T>IL void maxa(T &x,T y) {if (x<y) x=y;}
    template<class T>IL void mina(T &x,T y) {if (x>y) x=y;} 
    template<class T>IL T MAX(T x,T y){return x>y?x:y;}
    template<class T>IL T MIN(T x,T y){return x<y?x:y;}
};
using namespace IO;
double ee=1.00000000000000;
double pi=acos(-1.0);
const int mo=1e9+7;
const int N=4e5+10;
int r[N],n,m,l,inv[N];
int x1[N],x2[N];
struct cp{
    double a,b;
    IL cp operator +(const cp &o) const
    {
        return (cp){a+o.a,b+o.b};
    }
    IL cp operator -(const cp &o) const
    {
        return (cp){a-o.a,b-o.b};
    }
    IL cp operator *(register const cp &o) const
    {
        return (cp){a*o.a-b*o.b,o.a*b+o.b*a};
    }
}a[N],b[N],w[N],c[N],d[N];
IL int fsp(int x,int y)
{
    ll now=1;
    while (y)
    {
        if (y&1) now=now*x%mo;
        x=1ll*x*x%mo;
        y>>=1;
    }
    return now;
}
IL void clear()
{
    for (int i=0;i<=n;i++) a[i].a=a[i].b=b[i].a=b[i].b=c[i].a=c[i].b=d[i].a=d[i].b=0;
}
IL void fft_init()
{
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) w[i]=(cp){std::cos(pi*i/n),std::sin(pi*i/n)};
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
}
void fft(cp *a,int o)
{
    for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
    for (int i=1;i<n;i<<=1)
        for (int j=0;j<n;j+=(i*2))
        {
          cp *x1=a+j,*x2=a+i+j;
          for (int k=0;k<i;k++,x1++,x2++)
          {
              cp W=w[n/i*k]; W.b*=o;
              cp x=*x1,y=W*(*x2);
              *x1=x+y,*x2=x-y;
          }
        }
    if (o==-1) for(int i=0;i<n;i++) a[i].a/=n;
}
IL void getcj(int *A,int *B,int len)
{
    for (int i=0;i<len;i++)
    {
       a[i]=(cp){A[i]&32767,A[i]>>15};
       b[i]=(cp){B[i]&32767,B[i]>>15};
    }
    m=len*2; fft_init();
    fft(a,1); fft(b,1);
    for (int i=0;i<n;i++)
    {
        int j=(n-1)&(n-i);
        c[j]=(cp){0.5*(a[i].a+a[j].a),0.5*(a[i].b-a[j].b)}*b[i];
        d[j]=(cp){0.5*(a[i].b+a[j].b),0.5*(a[j].a-a[i].a)}*b[i];
    }
    fft(c,1); fft(d,1);
    double inv=ee/n;
    rep(i,0,n) c[i].a*=inv,c[i].b*=inv;
    rep(i,0,n) d[i].a*=inv,d[i].b*=inv;
    rep(i,0,len)
    {
        ll a1=c[i].a+0.5,a2=c[i].b+0.5;
        ll a3=d[i].a+0.5,a4=d[i].b+0.5;
        B[i]=(a1+((a2+a3)<<15)+((a4%mo)<<30))%mo;
    }
    clear();
}
void getinv(int *A,int *B,int len)
{
    if (len==1) { B[0]=fsp(A[0],mo-2); return;};
    getinv(A,B,(len+1)/2);
    int C[N]={};
    rep(i,0,len-1) C[i]=A[i];
    getcj(B,C,len);
    getcj(B,C,len);
    for (int i=0;i<len;i++) B[i]=((2ll*B[i]-C[i])%mo+mo)%mo;
}
int main()
{
    int n1;
    read(n1);
    rep(i,0,n1-1) read(x1[i]);
    getinv(x1,x2,n1);
    rep(i,0,n1-1) x2[i]=(x2[i]+mo)%mo;
    rep(i,0,n1-1) wer(x2[i]),wer1();
    fwrite(sr,1,C1+1,stdout);
    return 0;
}

 

总结:

不要把len和n写错

不要忘记初始化

下一层是(len+1)/2

posted @ 2018-12-05 19:39  尹吴潇  阅读(753)  评论(0编辑  收藏  举报