多项式封装

推销一下基于继承的多项式封装,好多函数可以直接用vector的,不用再次封装了,省事很多
另外()运算符太赞了,虽然耗时大约是resize()的4倍,但是很好用!!再也不用费力计算多项式的大小了!!!
效率还不错,基本上跑得比大部分取模NTT快,但是比Muel_imj的不取模NTT慢(反向引个流)

全(半)家桶
#include<bits/stdc++.h>
using namespace std;
// 64-bit signed integer type used for the modular arithmetic below.
using ll = long long;
// Modulus for every arithmetic operation in this program (998244353).
const ll MOD = 998244353;
// Capacity of the global tables indexed by transform position.
const int N = 4e5+10;
// Bit-reversal permutation table; filled by preNTT() and read by NTT().
int rev[N];
// Modular exponentiation by repeated squaring: returns a^b mod MOD.
// The base is reduced into [0, MOD) up front so that the 64-bit products
// below cannot overflow and the result is always in [0, MOD), even when the
// caller passes a negative or oversized base.
// b is expected to be non-negative: the loop ends when b has been shifted
// down to zero.
ll qpow(ll a, ll b){
	a %= MOD;
	if(a < 0) a += MOD;
	ll ret = 1;
	for(; b ; b >>= 1){
		if(b & 1) ret = ret * a % MOD;
		a = a * a % MOD;
	}
	return ret;
}
// Returns x^(MOD-2) mod MOD via qpow, i.e. the modular inverse of x for the
// prime modulus used here (Fermat's little theorem).
ll ginv(ll x){
	const ll exponent = MOD - 2;
	return qpow(x, exponent);
}
// Returns x + y with a single conditional subtraction of MOD when the sum
// reaches MOD.
ll add(ll x, ll y){
	const ll sum = x + y;
	return sum >= MOD ? sum - MOD : sum;
}
// Returns x - y with a single conditional addition of MOD when the
// difference is negative.
ll sub(ll x, ll y){
	const ll diff = x - y;
	return diff < 0 ? diff + MOD : diff;
}
ll w[N];
int preNTT(int len){
	int deg = 1;
	while(deg < len) deg *= 2;
	for(int i = 0; i < deg; ++i)
		rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? deg / 2 : 0);
	w[0] = 1;
	w[1] = qpow(3, (MOD - 1) / deg);
	for(int i = 2; i < deg; ++i) w[i] = w[i - 1] * w[1]% MOD;
	return deg;
}
// Polynomial with coefficients modulo MOD; element i is the coefficient of
// x^i. Deriving from vector<ll> exposes every vector member directly.
// The transform-based operations read the global tables rev[] and w[], so
// preNTT(len) must have been called for the same deg before NTT(f, deg, ...).
// The definition also declares the global instance `f` used by main().
struct poly : vector<ll>{
	using vector ::vector;
	using vector :: operator [];
	// Grows the polynomial to at least x coefficients (new ones are zero).
	// Never shrinks; a non-positive x is ignored.
	void cksize(int x){
		if(x > 0 && size() < static_cast<size_t>(x)) resize(x);
	}
	// Auto-growing element access: makes sure index x exists, then returns a
	// reference to that coefficient.
	ll& operator ()(int x){
		cksize(x + 1);
		return this->operator[](x);
	}
	// Formal derivative. The result has one coefficient fewer than f;
	// an empty polynomial is returned unchanged (pop_back on it would be UB).
	friend poly diff(poly f){
		if(f.empty()) return f;
		const int n = static_cast<int>(f.size());
		for(int i = 0; i + 1 < n; ++i){
			f[i] = f[i + 1] * (i + 1) % MOD;
		}
		f.pop_back();
		return f;
	}
	// Formal integral with zero constant term, truncated to the size of f
	// (the highest input coefficient is dropped).
	friend poly inte(poly f){
		if(f.empty()) return {};
		// Walk downwards so f[i - 1] is still the original value when read.
		for(int i = int(f.size()) - 1; i >= 1; --i){
			f[i] = f[i - 1] * ginv(i) % MOD;
		}
		f[0] = 0;
		return f;
	}
	// Brings every coefficient into the canonical range [0, MOD).
	void makemod(){
		for(auto & it : (*this)){
			it = (it % MOD + MOD) % MOD;
		}
	}
	// In-place number-theoretic transform of length deg.
	// f is resized to deg first. opt == 1 performs the forward transform;
	// opt == -1 additionally reverses f[1..deg) and multiplies by deg^{-1},
	// which yields the inverse transform.
	friend void NTT(poly &f, int deg, int opt){
		f.resize(deg);
		for(int i = 0; i < deg; ++i){
			if(i < rev[i]) std::swap(f[i], f[rev[i]]);
		}
		// h: butterfly block size, m: half block, t: stride into w[] so that
		// w[t * j] is the j-th power of the h-th root of unity.
		for(int h = 2, m = 1, t = deg / 2; h <= deg; h *= 2, m *= 2, t /= 2){
			for(int l = 0; l < deg; l += h){
				for(int i = l, j = 0; i < l + m; ++j, ++i){
					ll x = f[i], y = w[t * j] * f[i + m] % MOD;
					f[i] = add(x, y);
					f[i + m] = sub(x, y);
				}
			}
		}
		f.makemod();
		if(opt == -1){
			reverse(f.begin() + 1, f.end());
			ll iv = ginv(deg);
			for(auto &it : f){
				it = it * iv % MOD;
			}
		}
	}
	// Pointwise product: f[i] *= g[i] for every index of g; f is grown to
	// g's size if it is shorter.
	friend poly dmul(poly f,const poly& g){
		f.cksize(g.size());
		const int n = static_cast<int>(g.size());
		for(int i = 0; i < n; ++i){
			f[i] = f[i] * g[i] % MOD;
		}
		return f;
	}
	// Coefficient-wise difference f - g; the result has max(|f|, |g|) terms.
	friend poly operator - (poly f, const poly& g){
		f.cksize(g.size());
		const int n = static_cast<int>(g.size());
		for(int i = 0; i < n; ++i){
			f[i] = (f[i] - g[i] + MOD) % MOD;
		}
		return f;
	}
	// Full product of f and g (|f| + |g| - 1 coefficients) computed with
	// forward transforms, a pointwise product and an inverse transform.
	// Returns an empty polynomial if either operand is empty.
	friend poly operator *(poly f, poly g){
		if(f.empty()|| g.empty()) return {};
		int len = f.size() + g.size() - 1;
		int deg = preNTT(len);
		NTT(f, deg, 1);
		NTT(g, deg, 1);
		f = dmul(std::move(f), g);
		NTT(f, deg, -1);
		f.resize(len);
		return f;
	}
	// Multiplicative inverse of f modulo x^|f| by Newton iteration:
	// ret <- ret * (2 - a * ret), doubling the precision each round.
	// Starts from ginv(f[0]), so f[0] is expected to be non-zero.
	friend poly pinv(const poly& f){
		if(f.empty()) return {};
		const int n = static_cast<int>(f.size());
		poly ret;
		ret(0) = ginv(f[0]);
		poly a;
		for(int len = 2; len < 2 * n; len *= 2){
			a.assign(f.begin(), f.begin() + min(len, n));
			// a * ret * ret has |a| + 2 * |ret| - 2 coefficients.
			int deg = preNTT(a.size() + 2 * ret.size() - 2);
			NTT(ret, deg, 1);
			NTT(a, deg, 1);
			for(int i = 0; i < deg; ++i)
				ret[i] = (2 - a[i] * ret[i] % MOD) * ret[i] % MOD;
			// The expression above may be negative; normalise before the
			// inverse transform.
			ret.makemod();
			NTT(ret, deg, -1);
			ret.resize(len); // truncate modulo x^len
		}
		ret.resize(f.size());
		return ret;
	}
	// Square root of f modulo x^|f| by the iteration res <- (res + a/res) / 2.
	// The iteration starts from res[0] = 1, so it presumably expects
	// f[0] == 1 -- confirm against callers.
	friend poly sqrt(const poly& f){
		const int n = static_cast<int>(f.size());
		poly res;
		res(0) = 1;
		poly a;
		ll iv2 = ginv(2);
		for(int len = 2; len < 2 * n; len *= 2){
			res.resize(len); // enough room, and pinv then works modulo x^len
			a.assign(f.begin(), f.begin() + min(len, n));
			a = a * pinv(res);
			for(int i = 0; i < len; ++i) res[i] = (res[i] + a[i]) * iv2 % MOD;
		}
		res.resize(f.size());
		return res;
	}
	// Debug helper: writes all coefficients to stderr on one line.
	void prt()const{
		for(auto it : (*this)){
			cerr << it <<" ";
		}
		cerr<<endl;
	}
	// Logarithm of f modulo x^|f|, computed as the integral of f' / f.
	friend poly ln(const poly& f){
		poly res = inte(diff(f) * pinv(f));
		res.resize(f.size());
		return res;
	}
	// Exponential of f modulo x^|f| by the iteration
	// ret <- ret * (1 + f - ln(ret)), doubling the precision each round.
	friend poly exp(const poly& f){
		const int n = static_cast<int>(f.size());
		poly ret;
		ret(0) = 1;
		poly a, b;
		for(int len = 2; len < 2 * n; len *= 2){
			ret.resize(len); // so that ln (and its pinv) work modulo x^len
			a = ln(ret);
			b.assign(f.begin(), f.begin() + min(len, n));
			b = b - a;
			// Add 1 to the constant term while staying inside [0, MOD).
			b[0] = add(b[0], 1);
			ret = ret * b;
			ret.resize(len); // truncate modulo x^len
		}
		ret.resize(f.size());
		return ret;
	}
}f;	
// Reads the next non-negative decimal integer from stdin, skipping any
// leading non-digit characters. Returns 0 if the input ends before a digit
// is found.
int read(){
	int x = 0;
	// Keep getchar()'s int result: storing it in a char loses EOF, which
	// made the skip loop spin forever at end of input.
	int ch = getchar();
	while(ch != EOF && !isdigit(ch)) ch = getchar();
	while(ch != EOF && isdigit(ch)){
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x;
}
// Reads n followed by n coefficients into the global polynomial f, replaces
// f with exp(f), and prints its first n coefficients separated by spaces.
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	int n;
	cin >> n;
	f.resize(n);
	for(auto &coef : f) cin >> coef;
	f = exp(f);
	int idx = 0;
	while(idx < n){
		cout << f[idx] <<" ";
		++idx;
	}
	cout <<endl;
	return 0;
}

这份NTT使用了比较简单的优化:
在线计算单位根
点值相乘从std::move变为了引用
luogu 1.48s

粗略优化的NTT
#include<bits/stdc++.h>
using namespace std;
// 64-bit signed integer type used for the modular arithmetic below.
using ll = long long;
// Modulus for every arithmetic operation in this program (998244353).
const ll MOD = 998244353;
// Capacity of the global bit-reversal table.
const int N = 4e6+10;
// Bit-reversal permutation table; filled by preNTT() and read by NTT().
int rev[N];
// Modular exponentiation by repeated squaring: returns a^b mod MOD.
// The base is reduced into [0, MOD) up front so that the 64-bit products
// below cannot overflow and the result is always in [0, MOD), even when the
// caller passes a negative or oversized base.
// b is expected to be non-negative: the loop ends when b has been shifted
// down to zero.
ll qpow(ll a, ll b){
	a %= MOD;
	if(a < 0) a += MOD;
	ll ret = 1;
	for(; b ; b >>= 1){
		if(b & 1) ret = ret * a % MOD;
		a = a * a % MOD;
	}
	return ret;
}
// Returns x^(MOD-2) mod MOD via qpow, i.e. the modular inverse of x for the
// prime modulus used here (Fermat's little theorem).
ll ginv(ll x){
	const ll exponent = MOD - 2;
	return qpow(x, exponent);
}
// Returns x + y with a single conditional subtraction of MOD when the sum
// reaches MOD.
ll add(ll x, ll y){
	const ll sum = x + y;
	return sum >= MOD ? sum - MOD : sum;
}
// Returns x - y with a single conditional addition of MOD when the
// difference is negative.
ll sub(ll x, ll y){
	const ll diff = x - y;
	return diff < 0 ? diff + MOD : diff;
}
int preNTT(int len){
	int deg = 1;
	while(deg < len) deg *= 2;
	for(int i = 0; i < deg; ++i)
		rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? deg / 2 : 0);
	return deg;
}
// Polynomial with coefficients modulo MOD; element i is the coefficient of
// x^i. Deriving from vector<ll> exposes every vector member directly.
// NTT() reads the global table rev[], so preNTT(len) must have been called
// for the same deg beforehand.
// The definition also declares a global instance `f`.
struct poly : vector<ll>{
	using vector ::vector;
	using vector :: operator [];
	// Grows the polynomial to at least x coefficients (new ones are zero).
	// Never shrinks; a non-positive x is ignored.
	void cksize(int x){
		if(x > 0 && size() < static_cast<size_t>(x)) resize(x);
	}
	// Auto-growing element access: makes sure index x exists, then returns a
	// reference to that coefficient.
	ll& operator ()(int x){
		cksize(x + 1);
		return this->operator[](x);
	}
	// Formal derivative. The result has one coefficient fewer than f;
	// an empty polynomial is returned unchanged (pop_back on it would be UB).
	friend poly diff(poly f){
		if(f.empty()) return f;
		const int n = static_cast<int>(f.size());
		for(int i = 0; i + 1 < n; ++i){
			f[i] = f[i + 1] * (i + 1) % MOD;
		}
		f.pop_back();
		return f;
	}
	// Formal integral with zero constant term, truncated to the size of f
	// (the highest input coefficient is dropped).
	friend poly inte(poly f){
		if(f.empty()) return {};
		// Walk downwards so f[i - 1] is still the original value when read.
		for(int i = int(f.size()) - 1; i >= 1; --i){
			f[i] = f[i - 1] * ginv(i) % MOD;
		}
		f[0] = 0;
		return f;
	}
	// Brings every coefficient into the canonical range [0, MOD).
	void makemod(){
		for(auto & it : (*this)){
			it = (it % MOD + MOD) % MOD;
		}
	}
	// In-place number-theoretic transform of length deg.
	// f is resized to deg first. opt == 1 performs the forward transform;
	// opt == -1 additionally reverses f[1..deg) and multiplies by deg^{-1},
	// which yields the inverse transform.
	// The root of unity for each layer is computed on the fly instead of
	// being read from a precomputed table.
	friend void NTT(poly &f, int deg, int opt){
		f.resize(deg);
		for(int i = 0; i < deg; ++i){
			if(i < rev[i]) std::swap(f[i], f[rev[i]]);
		}
		// h: butterfly block size, m: half block.
		for(int h = 2, m = 1; h <= deg; h *= 2, m *= 2){
			ll w1 = qpow(3, (MOD - 1) / h);
			for(int l = 0; l < deg; l += h){
				ll w0 = 1; // running power of w1 inside the block
				for(int i = l; i < l + m; ++i){
					ll x = f[i], y = w0 * f[i + m] % MOD;
					f[i] = add(x, y);
					f[i + m] = sub(x, y);
					w0 = w0 * w1 % MOD;
				}
			}
		}
		f.makemod();
		if(opt == -1){
			reverse(f.begin() + 1, f.end());
			ll iv = ginv(deg);
			for(auto &it : f){
				it = it * iv % MOD;
			}
		}
	}
	// In-place pointwise product: f[i] *= g[i] for every index of g.
	// f must have at least as many coefficients as g.
	friend void dmul(poly& f, const poly& g){
		const int n = static_cast<int>(g.size());
		for(int i = 0; i < n; ++i){
			f[i] = f[i] * g[i] % MOD;
		}
	}
	// Coefficient-wise difference f - g; the result has max(|f|, |g|) terms.
	friend poly operator - (poly f, const poly& g){
		f.cksize(g.size());
		const int n = static_cast<int>(g.size());
		for(int i = 0; i < n; ++i){
			f[i] = (f[i] - g[i] + MOD) % MOD;
		}
		return f;
	}
	// Full product of f and g (|f| + |g| - 1 coefficients) computed with
	// forward transforms, a pointwise product and an inverse transform.
	// Returns an empty polynomial if either operand is empty.
	friend poly operator *(poly f, poly g){
		if(f.empty()|| g.empty()) return {};
		int len = f.size() + g.size() - 1;
		int deg = preNTT(len);
		NTT(f, deg, 1);
		NTT(g, deg, 1);
		dmul(f, g);
		NTT(f, deg, -1);
		f.resize(len);
		return f;
	}
}f;	
// Reads the next non-negative decimal integer from stdin, skipping any
// leading non-digit characters. Returns 0 if the input ends before a digit
// is found.
int read(){
	int x = 0;
	// Keep getchar()'s int result: storing it in a char loses EOF, which
	// made the skip loop spin forever at end of input.
	int ch = getchar();
	while(ch != EOF && !isdigit(ch)) ch = getchar();
	while(ch != EOF && isdigit(ch)){
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x;
}
// Reads the degrees n and m, then n + 1 and m + 1 coefficients, multiplies
// the two polynomials and prints every coefficient of the product.
int main(){
	const int n = read() + 1;
	const int m = read() + 1;
	poly f(n), g(m);
	for(auto &coef : f) coef = read();
	for(auto &coef : g) coef = read();
	f = f * g;
	for(size_t i = 0; i < f.size(); ++i){
		printf("%lld ", f[i]);
	}
	printf("\n");
	return 0;
}

这份NTT的多项式部分从 long long 改为int
luogu 1.32s

进一步优化的NTT
#include<bits/stdc++.h>
using namespace std;
// 64-bit signed integer type; used here for the modulus and wide products.
using ll = long long;
// Modulus for every arithmetic operation in this program (998244353).
const ll MOD = 998244353;
// Capacity of the global bit-reversal table.
const int N = 4e6+10;
// Bit-reversal permutation table; filled by preNTT() and read by NTT().
int rev[N];
// Modular exponentiation by repeated squaring: returns a^b mod MOD.
// The base is reduced into [0, MOD) first; without that a negative int
// would be converted to a huge unsigned value by the 1llu promotion below
// and produce a wrong residue.
// b is expected to be non-negative: the loop ends when b has been shifted
// down to zero.
int qpow(int a, int b){
	a %= MOD;
	if(a < 0) a += MOD;
	int ret = 1;
	for(; b ; b >>= 1){
		if(b & 1) ret = 1llu * ret * a % MOD;
		a = 1llu * a * a % MOD;
	}
	return ret;
}
// Returns x^(MOD-2) mod MOD via qpow, i.e. the modular inverse of x for the
// prime modulus used here (Fermat's little theorem).
int ginv(int x){
	const int exponent = MOD - 2;
	return qpow(x, exponent);
}
// Returns x + y with a single conditional subtraction of MOD when the sum
// reaches MOD.
int add(int x, int y){
	const int sum = x + y;
	return sum >= MOD ? sum - MOD : sum;
}
// Returns x - y, adding MOD first when x is smaller than y so the result
// does not go negative.
int sub(int x, int y){
	return x < y ? x + MOD - y : x - y;
}
int preNTT(int len){
	int deg = 1;
	while(deg < len) deg *= 2;
	for(int i = 0; i < deg; ++i)
		rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? deg / 2 : 0);
	return deg;
}
// Polynomial with int coefficients modulo MOD; element i is the coefficient
// of x^i. Deriving from vector<int> exposes every vector member directly.
// NTT() reads the global table rev[], so preNTT(len) must have been called
// for the same deg beforehand.
// The definition also declares a global instance `f`.
struct poly : vector<int>{
	using vector ::vector;
	using vector :: operator [];
	// In-place number-theoretic transform of length deg.
	// f is resized to deg first. opt == 1 performs the forward transform;
	// opt == -1 additionally reverses f[1..deg) and multiplies by deg^{-1},
	// which yields the inverse transform.
	friend void NTT(poly &f, int deg, int opt){
		f.resize(deg);
		for(int i = 0; i < deg; ++i){
			if(i < rev[i]) std::swap(f[i], f[rev[i]]);
		}
		// h: butterfly block size, m: half block.
		for(int h = 2, m = 1; h <= deg; h *= 2, m *= 2){
			const int w1 = qpow(3, (MOD - 1) / h);
			for(int l = 0; l < deg; l += h){
				int w0 = 1; // running power of w1 inside the block
				for(int i = l; i < l + m; ++i){
					const int x = f[i];
					const int y = 1ll * w0 * f[i + m] % MOD;
					const int s = x + y;
					// >= (not >) so that a sum equal to MOD wraps to 0 and
					// coefficients stay inside [0, MOD).
					f[i] = s >= MOD ? static_cast<int>(s - MOD) : s;
					f[i + m] = x < y ? MOD + x - y : x - y;
					w0 = 1ll * w0 * w1 % MOD;
				}
			}
		}
		if(opt == -1){
			reverse(f.begin() + 1, f.end());
			const ll iv = ginv(deg);
			for(auto &it : f){
				it = it * iv % MOD;
			}
		}
	}
	// In-place pointwise product: f[i] *= g[i] for every index of g.
	// f must have at least as many coefficients as g.
	friend void dmul(poly& f, const poly& g){
		const int n = static_cast<int>(g.size());
		for(int i = 0; i < n; ++i){
			f[i] = 1ll * f[i] * g[i] % MOD;
		}
	}
	// Replaces f with the full product f * g (|f| + |g| - 1 coefficients),
	// computed with forward transforms, a pointwise product and an inverse
	// transform. f becomes empty if either operand is empty.
	friend void operator *=(poly& f, poly g){
		if(f.empty()|| g.empty()) {
			f.clear();
			return;
		}
		int len = f.size() + g.size() - 1;
		int deg = preNTT(len);
		NTT(f, deg, 1);
		NTT(g, deg, 1);
		dmul(f, g);
		NTT(f, deg, -1);
		f.resize(len);
	}
	// Returns the full product of f and g; shares its implementation with
	// operator*= so the two cannot drift apart.
	friend poly operator *(poly f, poly g){
		f *= std::move(g);
		return f;
	}
}f;	
// Reads the next non-negative decimal integer from stdin, skipping any
// leading non-digit characters. Returns 0 if the input ends before a digit
// is found.
int read(){
	int x = 0;
	// Keep getchar()'s int result: storing it in a char loses EOF, which
	// made the skip loop spin forever at end of input.
	int ch = getchar();
	while(ch != EOF && !isdigit(ch)) ch = getchar();
	while(ch != EOF && isdigit(ch)){
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x;
}

// Reads the degrees n and m, then n + 1 and m + 1 coefficients, multiplies
// the two polynomials and prints every coefficient of the product.
int main(){
	const int n = read() + 1;
	const int m = read() + 1;
	poly f(n), g(m);
	for(auto &coef : f) coef = read();
	for(auto &coef : g) coef = read();
	f = f * g;
	for(size_t i = 0; i < f.size(); ++i){
		printf("%d ", f[i]);
	}
	printf("\n");
	return 0;
}
posted @ 2023-01-04 15:15  CDsidi  阅读(59)  评论(0编辑  收藏  举报