多项式封装
推销一下基于继承的多项式封装,好多函数可以直接用vector的,不用再次封装了,省事很多
另外 `()` 运算符太赞了,虽然耗时是直接 `resize()` 的\(4\)倍,但是很好用,再也不用费力计算多项式的大小了!!!
效率还不错,基本上跑得比大部分取模NTT快,但是比Muel_imj的不取模NTT慢(反向引个流)
全(半)家桶
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
// Prime modulus 998244353 = 119 * 2^23 + 1, the usual NTT-friendly prime.
constexpr ll MOD = 998244353;
// Capacity of the global rev[] / w[] tables: preNTT() must never return a length above N.
constexpr int N = 400010;
int rev[N];
// Binary exponentiation: returns a^b mod MOD.
// a is expected in [0, MOD) so that a * a fits in long long.
ll qpow(ll a, ll b){
    ll acc = 1;
    while(b != 0){
        if(b & 1) acc = acc * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return acc;
}
// Modular inverse via Fermat's little theorem (MOD is prime): x^(MOD-2) mod MOD.
ll ginv(ll x){
    const ll exponent = MOD - 2;
    return qpow(x, exponent);
}
// Modular addition for operands already in [0, MOD).
ll add(ll x, ll y){
    const ll sum = x + y;
    return sum >= MOD ? sum - MOD : sum;
}
// Modular subtraction for operands already in [0, MOD).
ll sub(ll x, ll y){
    const ll diff = x - y;
    return diff < 0 ? diff + MOD : diff;
}
ll w[N];
int preNTT(int len){
int deg = 1;
while(deg < len) deg *= 2;
for(int i = 0; i < deg; ++i)
rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? deg / 2 : 0);
w[0] = 1;
w[1] = qpow(3, (MOD - 1) / deg);
for(int i = 2; i < deg; ++i) w[i] = w[i - 1] * w[1]% MOD;
return deg;
}
// Polynomial over Z/MOD with coefficients stored low-degree first (f[i] is the
// coefficient of x^i). Publicly inherits std::vector<ll>, so size(), resize(),
// assign(), ... are used directly. Unless noted otherwise, the routines expect
// every coefficient to already lie in [0, MOD).
struct poly : vector<ll>{
    using vector ::vector;
    using vector :: operator [];

    // Grow (zero-filling) to at least x coefficients; never shrinks.
    // The x > 0 guard keeps a negative x from being converted to a huge size_t.
    void cksize(int x){
        if(x > 0 && size() < static_cast<size_t>(x)) resize(x);
    }

    // Auto-growing accessor: ensures index x exists, then returns a reference to it.
    ll& operator ()(int x){
        cksize(x + 1);
        return this->operator[](x);
    }

    // Formal derivative. The result has one coefficient fewer than f.
    friend poly diff(poly f){
        // pop_back() on an empty vector is undefined behaviour.
        if(f.empty()) return {};
        for(size_t i = 0; i + 1 < f.size(); ++i){
            f[i] = f[i + 1] * static_cast<ll>(i + 1) % MOD;
        }
        f.pop_back();
        return f;
    }

    // Formal integral with constant term 0. The length is kept, so the
    // highest coefficient of f does not contribute to the result.
    friend poly inte(poly f){
        if(f.empty()) return {};
        for(int i = int(f.size()) - 1; i >= 1; --i){
            f[i] = f[i - 1] * ginv(i) % MOD;
        }
        f[0] = 0;
        return f;
    }

    // Reduce every coefficient (possibly negative) into [0, MOD).
    void makemod(){
        for(auto & it : (*this)){
            it = (it % MOD + MOD) % MOD;
        }
    }

    // In-place iterative NTT of length deg; f is resized to deg first.
    // Must be preceded by a preNTT() call that returned this deg, because it
    // reads the global rev[] and w[] tables filled there.
    // opt == 1: forward transform. opt == -1: inverse transform, including
    // the 1/deg scaling.
    friend void NTT(poly &f, int deg, int opt){
        f.resize(deg);
        for(int i = 0; i < deg; ++i){
            if(i < rev[i]) std::swap(f[i], f[rev[i]]);
        }
        // h: butterfly block length, m = h / 2, t = deg / h is the stride into
        // w[] that makes w[t * j] the j-th power of an h-th root of unity.
        for(int h = 2, m = 1, t = deg / 2; h <= deg; h *= 2, m *= 2, t /= 2){
            for(int l = 0; l < deg; l += h){
                for(int i = l, j = 0; i < l + m; ++j, ++i){
                    ll x = f[i], y = w[t * j] * f[i + m] % MOD;
                    f[i] = add(x, y);
                    f[i + m] = sub(x, y);
                }
            }
        }
        f.makemod();
        if(opt == -1){
            // Inverse NTT = forward NTT with outputs 1..deg-1 reversed, times 1/deg.
            reverse(f.begin() + 1, f.end());
            ll iv = ginv(deg);
            for(auto &it : f){
                it = it * iv % MOD;
            }
        }
    }

    // Pointwise product f[i] * g[i]; f is grown to g.size() if it is shorter.
    friend poly dmul(poly f,const poly& g){
        f.cksize(g.size());
        for(size_t i = 0; i < g.size(); ++i){
            f[i] = f[i] * g[i] % MOD;
        }
        return f;
    }

    // Coefficient-wise f - g mod MOD; the result has max(|f|, |g|) coefficients.
    friend poly operator - (poly f, const poly& g){
        f.cksize(g.size());
        for(size_t i = 0; i < g.size(); ++i){
            f[i] = (f[i] - g[i] + MOD) % MOD;
        }
        return f;
    }

    // Polynomial product via NTT. The result has |f| + |g| - 1 coefficients,
    // or is empty when either operand is empty.
    friend poly operator *(poly f, poly g){
        if(f.empty()|| g.empty()) return {};
        int len = f.size() + g.size() - 1;
        int deg = preNTT(len);
        NTT(f, deg, 1);
        NTT(g, deg, 1);
        f = dmul(std::move(f), g);
        NTT(f, deg, -1);
        f.resize(len);
        return f;
    }

    // Power-series inverse: returns g with f * g == 1 (mod x^{|f|}).
    // Uses the Newton iteration g <- g * (2 - f * g), doubling the precision each step.
    // Requires f[0] != 0 (mod MOD).
    friend poly pinv(const poly& f){
        if(f.empty()) return {};
        poly ret;
        ret(0) = ginv(f[0]);
        poly a;
        for(int len = 2; len < (2 * f.size()); len *= 2){
            a.assign(f.begin(), f.begin() + min(len, (int)f.size()));
            // Length of a * ret * ret, so the cyclic transform does not wrap around.
            int deg = preNTT(a.size() + 2 * ret.size() - 2);
            NTT(ret, deg, 1);
            NTT(a, deg, 1);
            for(int i = 0; i < deg; ++i)
                ret[i] = (2 - a[i] * ret[i] % MOD) * ret[i] % MOD;
            // (2 - a * ret) can be negative; the NTT needs inputs in [0, MOD).
            ret.makemod();
            NTT(ret, deg, -1);
            ret.resize(len); // truncate to the new precision: mod x^len
        }
        ret.resize(f.size());
        return ret;
    }

    // Power-series square root mod x^{|f|} via the Newton iteration g <- (g + f / g) / 2.
    // Starts from g = 1, so it assumes f[0] == 1.
    friend poly sqrt(const poly& f){
        poly res;
        res(0) = 1;
        poly a;
        ll iv2 = ginv(2);
        for(int len = 2; len < 2 * f.size(); len *= 2){
            res.resize(len); // room for the new terms; also makes pinv(res) work mod x^len
            a.assign(f.begin(), f.begin() + min(len, (int)f.size()));
            a = a * pinv(res);
            for(int i = 0; i < len; ++i) res[i] = (res[i] + a[i]) * iv2 % MOD;
        }
        res.resize(f.size());
        return res;
    }

    // Debug helper: print the coefficients to stderr.
    void prt()const{
        for(auto it : (*this)){
            cerr << it <<" ";
        }
        cerr<<endl;
    }

    // ln(f) mod x^{|f|}, computed as the integral of f' / f.
    // Assumes f[0] == 1; the result's constant term is always 0.
    friend poly ln(const poly& f){
        poly res = inte(diff(f) * pinv(f));
        res.resize(f.size());
        return res;
    }

    // exp(f) mod x^{|f|} via the Newton iteration g <- g * (1 - ln g + f).
    // Assumes f[0] == 0.
    friend poly exp(const poly& f){
        poly ret;
        ret(0) = 1;
        poly a, b;
        for(int len = 2; len < f.size() * 2; len *= 2){
            ret.resize(len); // so that ln(ret) (and the pinv inside it) works mod x^len
            a = ln(ret);
            b.assign(f.begin(), f.begin() + min(len, (int)f.size()));
            b = b - a;
            // Adding 1 through add() keeps b[0] in [0, MOD), as the NTT inside * requires.
            b[0] = add(b[0], 1);
            ret = ret * b;
            ret.resize(len); // truncate: mod x^len
        }
        ret.resize(f.size());
        return ret;
    }
}f;
// Reads the next unsigned decimal integer from stdin, skipping any non-digit
// characters in front of it. Returns 0 if end of input comes first.
// The value must fit in int. ch is an int so that EOF is distinguishable from
// data and isdigit() only ever receives EOF or unsigned-char values; anything
// else is undefined behaviour.
int read(){
    int x = 0;
    int ch = getchar();
    // Stop at EOF instead of spinning forever.
    while(ch != EOF && !isdigit(ch)) ch = getchar();
    while(isdigit(ch)){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    // Input: n, then the n coefficients of f (exp() assumes f[0] == 0).
    int n;
    cin >> n;
    f.resize(n);
    for(auto &coef : f) cin >> coef;
    // Output: the first n coefficients of exp(f) mod 998244353.
    f = exp(f);
    for(int i = 0; i < n; ++i) cout << f[i] << " ";
    cout << endl;
    return 0;
}
这份NTT使用了比较简单的优化:
在线计算单位根(每一层用 qpow 现算,不再预处理 w 数组)
点值相乘(dmul)从传值 + std::move 改为了传引用
luogu 1.48s
粗略优化的NTT
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
// Prime modulus 998244353 = 119 * 2^23 + 1, the usual NTT-friendly prime.
constexpr ll MOD = 998244353;
// Capacity of rev[]: preNTT() must never return a length above N.
constexpr int N = 4000010;
int rev[N];
// Binary exponentiation: returns a^b mod MOD.
// a is expected in [0, MOD) so that a * a fits in long long.
ll qpow(ll a, ll b){
    ll acc = 1;
    while(b != 0){
        if(b & 1) acc = acc * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return acc;
}
// Modular inverse via Fermat's little theorem (MOD is prime): x^(MOD-2) mod MOD.
ll ginv(ll x){
    const ll exponent = MOD - 2;
    return qpow(x, exponent);
}
// Modular addition for operands already in [0, MOD).
ll add(ll x, ll y){
    const ll sum = x + y;
    return sum < MOD ? sum : sum - MOD;
}
// Modular subtraction for operands already in [0, MOD).
ll sub(ll x, ll y){
    const ll diff = x - y;
    return diff >= 0 ? diff : diff + MOD;
}
// Returns deg, the smallest power of two >= len (1 when len <= 1), and fills
// rev[0..deg) with the bit-reversal permutation used by NTT().
int preNTT(int len){
    int deg = 1;
    while(deg < len) deg <<= 1;
    const int topBit = deg >> 1;
    for(int i = 0; i < deg; ++i){
        // reverse(i) = reverse(i / 2) / 2, plus the top bit when i is odd.
        const int low = (i & 1) ? topBit : 0;
        rev[i] = (rev[i >> 1] >> 1) | low;
    }
    return deg;
}
// Polynomial over Z/MOD with coefficients stored low-degree first (f[i] is the
// coefficient of x^i). Publicly inherits std::vector<ll>. Unless noted
// otherwise, the routines expect every coefficient to lie in [0, MOD).
struct poly : vector<ll>{
    using vector ::vector;
    using vector :: operator [];

    // Grow (zero-filling) to at least x coefficients; never shrinks.
    // The x > 0 guard keeps a negative x from being converted to a huge size_t.
    void cksize(int x){
        if(x > 0 && size() < static_cast<size_t>(x)) resize(x);
    }

    // Auto-growing accessor: ensures index x exists, then returns a reference to it.
    ll& operator ()(int x){
        cksize(x + 1);
        return this->operator[](x);
    }

    // Formal derivative. The result has one coefficient fewer than f.
    friend poly diff(poly f){
        // pop_back() on an empty vector is undefined behaviour.
        if(f.empty()) return {};
        for(size_t i = 0; i + 1 < f.size(); ++i){
            f[i] = f[i + 1] * static_cast<ll>(i + 1) % MOD;
        }
        f.pop_back();
        return f;
    }

    // Formal integral with constant term 0. The length is kept, so the
    // highest coefficient of f does not contribute to the result.
    friend poly inte(poly f){
        if(f.empty()) return {};
        for(int i = int(f.size()) - 1; i >= 1; --i){
            f[i] = f[i - 1] * ginv(i) % MOD;
        }
        f[0] = 0;
        return f;
    }

    // Reduce every coefficient (possibly negative) into [0, MOD).
    void makemod(){
        for(auto & it : (*this)){
            it = (it % MOD + MOD) % MOD;
        }
    }

    // In-place iterative NTT of length deg; f is resized to deg first.
    // Must be preceded by a preNTT() call that returned this deg (it reads rev[]).
    // opt == 1: forward transform. opt == -1: inverse transform, including
    // the 1/deg scaling.
    friend void NTT(poly &f, int deg, int opt){
        f.resize(deg);
        for(int i = 0; i < deg; ++i){
            if(i < rev[i]) std::swap(f[i], f[rev[i]]);
        }
        for(int h = 2, m = 1; h <= deg; h *= 2, m *= 2){
            // h-th root of unity 3^((MOD-1)/h), computed on the fly for each level.
            ll w1 = qpow(3, (MOD - 1) / h);
            for(int l = 0; l < deg; l += h){
                ll w0 = 1; // w1^(i - l)
                for(int i = l; i < l + m; ++i){
                    ll x = f[i], y = w0 * f[i + m] % MOD;
                    f[i] = add(x, y);
                    f[i + m] = sub(x, y);
                    w0 = w0 * w1 % MOD;
                }
            }
        }
        f.makemod();
        if(opt == -1){
            // Inverse NTT = forward NTT with outputs 1..deg-1 reversed, times 1/deg.
            reverse(f.begin() + 1, f.end());
            ll iv = ginv(deg);
            for(auto &it : f){
                it = it * iv % MOD;
            }
        }
    }

    // In-place pointwise product f[i] *= g[i]; f is grown to g.size() if it is shorter.
    friend void dmul(poly& f, const poly& g){
        f.cksize(g.size());
        for(size_t i = 0 ; i < g.size(); ++i){
            f[i] = f[i] * g[i] % MOD;
        }
    }

    // Coefficient-wise f - g mod MOD; the result has max(|f|, |g|) coefficients.
    friend poly operator - (poly f, const poly& g){
        f.cksize(g.size());
        for(size_t i = 0; i < g.size(); ++i){
            f[i] = (f[i] - g[i] + MOD) % MOD;
        }
        return f;
    }

    // Polynomial product via NTT. The result has |f| + |g| - 1 coefficients,
    // or is empty when either operand is empty.
    friend poly operator *(poly f, poly g){
        if(f.empty()|| g.empty()) return {};
        int len = f.size() + g.size() - 1;
        int deg = preNTT(len);
        NTT(f, deg, 1);
        NTT(g, deg, 1);
        dmul(f, g);
        NTT(f, deg, -1);
        f.resize(len);
        return f;
    }
}f;
// Reads the next unsigned decimal integer from stdin, skipping any non-digit
// characters in front of it. Returns 0 if end of input comes first.
// The value must fit in int. ch is an int so that EOF is distinguishable from
// data and isdigit() only ever receives EOF or unsigned-char values; anything
// else is undefined behaviour.
int read(){
    int x = 0;
    int ch = getchar();
    // Stop at EOF instead of spinning forever.
    while(ch != EOF && !isdigit(ch)) ch = getchar();
    while(isdigit(ch)){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x;
}
int main(){
int n, m;
n = read(), m = read();
++n, ++m;
poly f, g;
f.resize(n), g.resize(m);
for(int i = 0; i < n; ++i) f[i] = read();
for(int i = 0; i < m; ++i) g[i] = read();
f = f * g;
for(auto it : f){
printf("%lld ", it);
}
printf("\n");
return 0;
}
这份NTT的多项式系数类型从 long long 改为了 int
luogu 1.32s
进一步优化的NTT
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
// Prime modulus 998244353 = 119 * 2^23 + 1, the usual NTT-friendly prime.
constexpr ll MOD = 998244353;
// Capacity of rev[]: preNTT() must never return a length above N.
constexpr int N = 4000010;
int rev[N];
// Binary exponentiation: returns a^b mod MOD.
// Products are formed in unsigned 64-bit arithmetic so they cannot overflow int.
int qpow(int a, int b){
    int acc = 1;
    while(b != 0){
        if(b & 1) acc = 1ull * acc * a % MOD;
        a = 1ull * a * a % MOD;
        b >>= 1;
    }
    return acc;
}
// Modular inverse via Fermat's little theorem (MOD is prime): x^(MOD-2) mod MOD.
int ginv(int x){
    const int exponent = MOD - 2;
    return qpow(x, exponent);
}
// Modular addition for operands in [0, MOD). 2 * (MOD - 1) still fits in int.
int add(int x, int y){
    const int sum = x + y;
    return sum < MOD ? sum : sum - MOD;
}
// Modular subtraction for operands in [0, MOD).
int sub(int x, int y){
    return x >= y ? x - y : x + MOD - y;
}
// Returns deg, the smallest power of two >= len (1 when len <= 1), and fills
// rev[0..deg) with the bit-reversal permutation used by NTT().
int preNTT(int len){
    int deg = 1;
    while(deg < len) deg <<= 1;
    const int topBit = deg >> 1;
    for(int i = 0; i < deg; ++i){
        // reverse(i) = reverse(i / 2) / 2, plus the top bit when i is odd.
        const int low = (i & 1) ? topBit : 0;
        rev[i] = (rev[i >> 1] >> 1) | low;
    }
    return deg;
}
// Polynomial over Z/MOD with int coefficients stored low-degree first (f[i] is
// the coefficient of x^i). Publicly inherits std::vector<int>; coefficients are
// expected to lie in [0, MOD).
struct poly : vector<int>{
    using vector ::vector;
    using vector :: operator [];

    // In-place iterative NTT of length deg; f is resized to deg first.
    // Must be preceded by a preNTT() call that returned this deg (it reads rev[]).
    // opt == 1: forward transform. opt == -1: inverse transform, including
    // the 1/deg scaling.
    friend void NTT(poly &f, int deg, int opt){
        f.resize(deg);
        for(int i = 0; i < deg; ++i){
            if(i < rev[i]) std::swap(f[i], f[rev[i]]);
        }
        for(int h = 2, m = 1; h <= deg; h *= 2, m *= 2){
            // h-th root of unity 3^((MOD-1)/h), computed on the fly for each level.
            int w1 = qpow(3, (MOD - 1) / h);
            for(int l = 0; l < deg; l += h){
                int w0 = 1; // w1^(i - l)
                for(int i = l; i < l + m; ++i){
                    int x = f[i], y = 1ll * w0 * f[i + m] % MOD;
                    // add()/sub() keep results in [0, MOD). A plain `x + y > MOD`
                    // test would leave x + y == MOD unreduced.
                    f[i] = add(x, y);
                    f[i + m] = sub(x, y);
                    w0 = 1ll * w0 * w1 % MOD;
                }
            }
        }
        if(opt == -1){
            // Inverse NTT = forward NTT with outputs 1..deg-1 reversed, times 1/deg.
            reverse(f.begin() + 1, f.end());
            ll iv = ginv(deg);
            for(auto &it : f){
                it = it * iv % MOD;
            }
        }
    }

    // In-place pointwise product f[i] *= g[i]; f is grown to g.size() if it is shorter.
    friend void dmul(poly& f, const poly& g){
        if(f.size() < g.size()) f.resize(g.size());
        for(size_t i = 0 ; i < g.size(); ++i){
            f[i] = 1ll * f[i] * g[i] % MOD;
        }
    }

    // f becomes f * g (via NTT) with |f| + |g| - 1 coefficients, or empty when
    // either operand is empty. Returns f so the result can be chained.
    friend poly& operator *=(poly& f, poly g){
        if(f.empty()|| g.empty()) {
            f.clear();
            return f;
        }
        int len = f.size() + g.size() - 1;
        int deg = preNTT(len);
        NTT(f, deg, 1);
        NTT(g, deg, 1);
        dmul(f, g);
        NTT(f, deg, -1);
        f.resize(len);
        return f;
    }

    // Polynomial product; shares its implementation with operator*=.
    friend poly operator *(poly f, poly g){
        f *= std::move(g);
        return f;
    }
}f;
// Reads the next unsigned decimal integer from stdin, skipping any non-digit
// characters in front of it. Returns 0 if end of input comes first.
// The value must fit in int. ch is an int so that EOF is distinguishable from
// data and isdigit() only ever receives EOF or unsigned-char values; anything
// else is undefined behaviour.
int read(){
    int x = 0;
    int ch = getchar();
    // Stop at EOF instead of spinning forever.
    while(ch != EOF && !isdigit(ch)) ch = getchar();
    while(isdigit(ch)){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x;
}
int main(){
    // Input: degrees n and m, then the n+1 and m+1 coefficients of the two polynomials.
    int n = read();
    int m = read();
    ++n;
    ++m;
    poly f(n), g(m);
    for(auto &coef : f) coef = read();
    for(auto &coef : g) coef = read();
    // Output: the coefficients of the product, lowest degree first.
    f = f * g;
    for(int coef : f) printf("%d ", coef);
    printf("\n");
    return 0;
}