不闻不若闻|

vanueber

园龄:2年6个月粉丝:0关注:2

基于NTT的高精度乘法封装

#include <bits/stdc++.h>
#define int long long
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define il inline
#define rg register

using namespace std;

inline int read()
{
    int w=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        w=(w<<1)+(w<<3)+(ch^48);
        ch=getchar();
    }
    return w*f;
}
void write(int x)
{
    if(x<0) x=-x;
    if(x>9) write(x/10);
    putchar(x%10+'0');
}

il int qpow(int a,int b,const int mod)
{
    rg int res=1;
    while(b)
    {
        if(b&1) res=res*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return res;
}
namespace poly
{
    const int g=114514,mod=998244353;
    const int inv_g=qpow(g,mod-2,mod);
    #define vint vector<int>
    #define sz(a) ((int)a.size())
    vint split(vint a,int sz)
    {
        a.resize(sz);
        return a;
    }
    void print(vint f)
    {
        for(int i:f)
        {
            cout<<i<<" ";
        }
        cout<<endl;
    }
    void NTT(vint& f)
    {
        int n=sz(f);
        if(n==1) return;
        vint f1(n/2),f2(n/2);
        for(rg int i=0;i<n/2;++i)
        {
            f1[i]=f[i<<1],f2[i]=f[i<<1|1];
        }
        NTT(f1),NTT(f2);
        int g_i=qpow(g,(mod-1)/n,mod),gk=1;
        for(int i=0;i<sz(f2);++i)
        {
            (f[i]=(f1[i]+f2[i]*gk)%mod)%=mod;
            (f[i+n/2]=(mod+(f1[i]-f2[i]*gk)%mod)%mod)%=mod;
            gk=g_i*gk%mod;
        }
    }
    void INTT(vint& f)
    {
        NTT(f);
        auto it=f.begin();
        ++it;
        reverse(it,f.end());
    }
    vint operator *(vint a,vint b)
    {
        int n=a.size(),m=b.size();
        vint res(n+m-1);
        int len=1;
        while(len<n+m-1) len<<=1;
        int inv_len=qpow(len,mod-2,mod); 
        a.resize(len),b.resize(len);
        NTT(a),NTT(b);
        for(int i=0;i<len;++i)
        {
            a[i]=a[i]*b[i]%mod;
        }
        
        INTT(a);
        for(int i=0;i<=n+m-2;++i)
        {
            res[i]=a[i]*inv_len%mod;
        }
        return res;
    }
    vint operator +(const vint &a,const vint &b)
    {
        int len=max(sz(a),sz(b));
        vint res(len);
        for(rg int i=0;i<len;++i)
        {
            res[i]=(a[i]+b[i])%mod;
        }
        return res;
    }
    vint operator -(const vint &a,const vint &b)
    {
        int len=max(sz(a),sz(b));
        vint res(len);
        for(rg int i=0;i<len;++i)
        {
            res[i]=((a[i]-b[i])%mod+mod)%mod;
        }
        return res;
    }
    vint operator *(const vint &a,int b)
    {
        vint res(a.size());
        for(rg int i=0;i<sz(a);++i)
        {
            res[i]=a[i]*b%mod;
        }
        return res;
    }
    vint inv(vint a)
    {
        int n=sz(a);
        if(n==1) return vint{qpow(a[0],mod-2,mod)};
        vint b=inv(split(a,(n+1)>>1));
        vint p1=b*2,p2=a*b*b;
        return split(p1-p2,n);
    }
}
using namespace poly;
#define bint vector<int>
#define uint unsigned int
namespace big_num
{
    bint to_big_int(string ss)
    {
        bint res(ss.size());
        for(int i=ss.size()-1;i>=0;--i)
        {
            res[ss.size()-i-1]=ss[i]-'0';
        }
        return res;
    }
    bint flat(bint a)
    {
        bint res;
        uint i=0,t=0;
        while(i<a.size()||t!=0)
        {
            if(i<a.size())
            {
                t=t+a[i];
            } 
            res.push_back(t%10);
            t/=10;
            ++i;
        }
        while(res.size()>1&&res.back()==0) res.pop_back();
        return res;
    }
    void print(bint a)
    {
        for(int i=a.size()-1;i>=0;--i) cout<<a[i];
    }
    bool operator >=(const bint a,const bint b)
    {
        if(a.size()!=b.size()) return a.size()>b.size();
        for(int i=a.size()-1;i>=0;--i)
        {
            if(a[i]!=b[i]) return a[i]>b[i];
        }
        return 1;
    }
    bint operator +(bint a,bint b)
    {
        bint res;
        int t=0;
        for(uint i=0;i<a.size()||i<b.size();++i)
        {
            if(i<a.size()) t+=a[i];
            if(i<b.size()) t+=b[i];
            res.push_back(t%10);
            t/=10;
        }
        if(t) res.push_back(1);
        return res;
    }
    bint operator -(bint a,bint b)//a>=b
    {
        bint res;
        int t=0;
        for(uint i=0;i<a.size();++i)
        {
            t=a[i]-t;
            if(i<b.size()) t-=b[i];
            res.push_back((t+10)%10);
            if(t<0) t=1;
            else t=0;
        }
        while(res.size()>1&&res.back()==0) res.pop_back();
        return res;
    }
    bint operator *(bint a,bint b)
    {
        bint res;
        res=poly::operator*(a,b);
        res=flat(res);
        return res;
    }
}
using namespace big_num;
bint A,B;
string a,b;

signed main()
{
}

本文作者:vanueber

本文链接:https://www.cnblogs.com/vanueber/p/18679575

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @   vanueber  阅读(7)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起