基于NTT的高精度乘法封装
#include <bits/stdc++.h>
#define int long long
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define il inline
#define rg register
using namespace std;
inline int read()
{
int w=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
w=(w<<1)+(w<<3)+(ch^48);
ch=getchar();
}
return w*f;
}
void write(int x)
{
if(x<0) x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
il int qpow(int a,int b,const int mod)
{
rg int res=1;
while(b)
{
if(b&1) res=res*a%mod;
a=a*a%mod;
b>>=1;
}
return res;
}
namespace poly
{
const int g=114514,mod=998244353;
const int inv_g=qpow(g,mod-2,mod);
#define vint vector<int>
#define sz(a) ((int)a.size())
vint split(vint a,int sz)
{
a.resize(sz);
return a;
}
void print(vint f)
{
for(int i:f)
{
cout<<i<<" ";
}
cout<<endl;
}
void NTT(vint& f)
{
int n=sz(f);
if(n==1) return;
vint f1(n/2),f2(n/2);
for(rg int i=0;i<n/2;++i)
{
f1[i]=f[i<<1],f2[i]=f[i<<1|1];
}
NTT(f1),NTT(f2);
int g_i=qpow(g,(mod-1)/n,mod),gk=1;
for(int i=0;i<sz(f2);++i)
{
(f[i]=(f1[i]+f2[i]*gk)%mod)%=mod;
(f[i+n/2]=(mod+(f1[i]-f2[i]*gk)%mod)%mod)%=mod;
gk=g_i*gk%mod;
}
}
void INTT(vint& f)
{
NTT(f);
auto it=f.begin();
++it;
reverse(it,f.end());
}
vint operator *(vint a,vint b)
{
int n=a.size(),m=b.size();
vint res(n+m-1);
int len=1;
while(len<n+m-1) len<<=1;
int inv_len=qpow(len,mod-2,mod);
a.resize(len),b.resize(len);
NTT(a),NTT(b);
for(int i=0;i<len;++i)
{
a[i]=a[i]*b[i]%mod;
}
INTT(a);
for(int i=0;i<=n+m-2;++i)
{
res[i]=a[i]*inv_len%mod;
}
return res;
}
vint operator +(const vint &a,const vint &b)
{
int len=max(sz(a),sz(b));
vint res(len);
for(rg int i=0;i<len;++i)
{
res[i]=(a[i]+b[i])%mod;
}
return res;
}
vint operator -(const vint &a,const vint &b)
{
int len=max(sz(a),sz(b));
vint res(len);
for(rg int i=0;i<len;++i)
{
res[i]=((a[i]-b[i])%mod+mod)%mod;
}
return res;
}
vint operator *(const vint &a,int b)
{
vint res(a.size());
for(rg int i=0;i<sz(a);++i)
{
res[i]=a[i]*b%mod;
}
return res;
}
vint inv(vint a)
{
int n=sz(a);
if(n==1) return vint{qpow(a[0],mod-2,mod)};
vint b=inv(split(a,(n+1)>>1));
vint p1=b*2,p2=a*b*b;
return split(p1-p2,n);
}
}
using namespace poly;
#define bint vector<int>
#define uint unsigned int
namespace big_num
{
bint to_big_int(string ss)
{
bint res(ss.size());
for(int i=ss.size()-1;i>=0;--i)
{
res[ss.size()-i-1]=ss[i]-'0';
}
return res;
}
bint flat(bint a)
{
bint res;
uint i=0,t=0;
while(i<a.size()||t!=0)
{
if(i<a.size())
{
t=t+a[i];
}
res.push_back(t%10);
t/=10;
++i;
}
while(res.size()>1&&res.back()==0) res.pop_back();
return res;
}
void print(bint a)
{
for(int i=a.size()-1;i>=0;--i) cout<<a[i];
}
bool operator >=(const bint a,const bint b)
{
if(a.size()!=b.size()) return a.size()>b.size();
for(int i=a.size()-1;i>=0;--i)
{
if(a[i]!=b[i]) return a[i]>b[i];
}
return 1;
}
bint operator +(bint a,bint b)
{
bint res;
int t=0;
for(uint i=0;i<a.size()||i<b.size();++i)
{
if(i<a.size()) t+=a[i];
if(i<b.size()) t+=b[i];
res.push_back(t%10);
t/=10;
}
if(t) res.push_back(1);
return res;
}
bint operator -(bint a,bint b)//a>=b
{
bint res;
int t=0;
for(uint i=0;i<a.size();++i)
{
t=a[i]-t;
if(i<b.size()) t-=b[i];
res.push_back((t+10)%10);
if(t<0) t=1;
else t=0;
}
while(res.size()>1&&res.back()==0) res.pop_back();
return res;
}
bint operator *(bint a,bint b)
{
bint res;
res=poly::operator*(a,b);
res=flat(res);
return res;
}
}
using namespace big_num;
bint A,B;
string a,b;
signed main()
{
}
本文作者:vanueber
本文链接:https://www.cnblogs.com/vanueber/p/18679575
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步