$ O(\text{poly}(n) \log n) $ 多项式(不完整)

补上了 vector 封装的版本

upd on 9.25：预处理了单位根（原根 G 的幂）。

upd on 10.7：修复 poly 的 ln 之前没有把 2 倍长度的缓冲区清零的问题。

upd on 2023.2.23：加了普通版多项式开根（要求常数项为 1）。

// Note: Exp below is the O(n \log^2 n) divide-and-conquer version.
#include <bits/stdc++.h>
using namespace std;
// 64-bit signed integer used for intermediate modular products.
using ll = long long;
// Read one decimal integer from stdin, skipping any non-digit characters
// before it. Each '-' seen while skipping toggles the sign. Returns 0 if
// EOF is reached before any digit.
inline int read() {
	// c is an int so that EOF stays distinguishable and every value handed to
	// isdigit() is either EOF or an unsigned-char value (a negative char is UB).
	int f=1, r=0, c=getchar();
	// Stop at EOF: with a char and no EOF check this loop never terminated.
	while (c!=EOF && !isdigit(c)) f^=c=='-', c=getchar();
	while (isdigit(c)) r=(r<<1)+(r<<3)+(c&15), c=getchar();
	return f?r:-r;
}
// Array capacity: transforms of length up to 2^19, plus a little slack.
const int N=(1<<19)|7;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1, with primitive root 3.
const int mod=998244353;
const int G=3;
// Modular inverse of G: 3 * 332748118 == 1 (mod 998244353).
const int invG=332748118;
// Modular arithmetic helpers. The single conditional correction in inc/dec
// assumes both operands already lie in [0, mod).
inline void inc(int &x, int y) {
	x += y;
	if (x >= mod) x -= mod;
}
inline void dec(int &x, int y) {
	x -= y;
	if (x < 0) x += mod;
}
inline void mul(int &x, int y) {x = static_cast<int>(1LL * x * y % mod);}
inline int add(int x, int y) {inc(x, y); return x;}
inline int sub(int x, int y) {dec(x, y); return x;}
// Binary exponentiation: a^b modulo `mod` for b >= 0 (returns 1 when b == 0).
inline int qpow(int a, int b) {
	long long base = a, result = 1;
	while (b) {
		if (b & 1) result = result * base % mod;
		base = base * base % mod;
		b >>= 1;
	}
	return static_cast<int>(result);
}
namespace Poly {
	// NTT-based polynomial routines modulo `mod` with primitive root G.
	// The raw routines below work in place on int arrays. They expect the length n
	// to be a power of two no larger than N - 7 (= 2^19). init_g() must be called
	// with a large enough size before the first transform.
	typedef uint64_t ull;
	// lstn: size that the bit-reversal table rev[] was last built for.
	// g: roots of unity laid out so that g[p + j] = w^j, where w = G^((mod-1)/(2p)).
	int lstn, rev[N], g[N]={1};
	// Build the bit-reversal permutation for size n; repeating the last n is a no-op.
	inline void init(int n) {
		if (lstn==n) return;
		lstn=n;
		for (int i=1; i<n; i++) rev[i]=rev[i>>1]>>1|(i&1?n>>1:0);
	}
	// Precompute the root table g for every transform size up to n.
	inline void init_g(int n) {
		g[0]=1;
		for (int i=1; i<n; i<<=1) {
			int p=i<<1, t=qpow(G, (mod-1)/p); g[i]=1;
			for (int j=i+1; j<p; j++) g[j]=(ll)g[j-1]*t%mod;
		}
	}
	// In-place length-n NTT (n a power of two, rev[] built by init(n)).
	// op != 0: forward transform; op == 0: inverse transform including the 1/n factor.
	// Butterflies are reduced lazily in 64-bit words. After k levels every entry is
	// below (k+1)*mod, so the product A[r]*g at level k is only guaranteed to fit in
	// 64 bits for k <= 18. n = 2^19 has 19 levels and could overflow, so all entries
	// are reduced once every 16 levels.
	inline void NTT(int n, int *f, int op) {
		static ull A[N];
		for (int i=0; i<n; i++) A[i]=f[rev[i]];
		for (int len=2, p=1, lv=1; len<=n; p=len, len<<=1, lv++) {
			if (lv%16==0)
				for (int i=0; i<n; i++) A[i]%=mod;
			for (int i=0; i<n; i+=len)
				for (int j=0; j<p; j++) {
					int l=i|j, r=l|p, tmp=A[r]*g[p|j]%mod;
					A[r]=A[l]+mod-tmp, A[l]+=tmp;
				}
		}
		for (int i=0; i<n; i++) f[i]=A[i]%mod;
		if (op) return;
		// n divides mod-1, so mod-(mod-1)/n is the modular inverse of n.
		int inv=mod-(mod-1)/n; reverse(f+1, f+n);
		for (int i=0; i<n; i++) mul(f[i], inv);
	}
	// a[i] *= b[i] (mod) for i < n.
	inline void px(int n, int *a, int *b) {
		for (int i=0; i<n; i++) mul(a[i], b[i]);
	}
	// Zero the first n entries of a.
	inline void clear(int n, int *a) {memset(a, 0, sizeof(*a)*n);}
	// Copy the first n entries of b into a.
	inline void cpy(int n, int *a, int *b) {memcpy(a, b, sizeof(*a)*n);}
	// a <- cyclic convolution of a and b with length n (a power of two); b is not modified.
	inline void Mul(int n, int *a, int *b) {
		static int B[N]; cpy(n, B, b);
		init(n), NTT(n, a, 1), NTT(n, B, 1), px(n, a, B), NTT(n, a, 0);
	}
	// a[0..n) <- a^{-1} mod x^n by Newton iteration (n a power of two, a[0] != 0).
	inline void Inv(int n, int *a) {
		static int b[N], c[N]; clear(n, b), b[0]=qpow(a[0], mod-2);
		for (int len=2, p=1; len<=n; p=len, len<<=1) {
			// c = a*b. Its low p coefficients are [1, 0, ...] and also absorb the cyclic
			// wrap-around, so drop them and multiply what is left by b once more.
			cpy(len, c, a), Mul(len, c, b), clear(p, c), Mul(len, c, b);
			// b_new = b - b*(a*b - 1); only coefficients [p, len) change.
			for (int i=p; i<len; i++) b[i]=(mod-c[i])%mod;
		}
		cpy(n, a, b);
	}
	// a[0..n) <- derivative of a (the top coefficient becomes 0).
	inline void Der(int n, int *a) {
		for (int i=1; i<n; i++) a[i-1]=(ll)a[i]*i%mod;
		a[n-1]=0;
	}
	// inv[i] = modular inverse of i, extended lazily up to the largest n requested.
	int lstn_inv=1, inv[N]={0, 1};
	inline void init_inv(int n) {
		if (lstn_inv>=n) return;
		for (int i=lstn_inv+1; i<=n; i++) inv[i]=mod-(ll)(mod/i)*inv[mod%i]%mod;
		lstn_inv=n;
	}
	// a[0..n) <- integral of a with zero constant term (the coefficient past n-1 is dropped).
	inline void Inte(int n, int *a) {
		init_inv(n-1);
		for (int i=n-1; i; i--) a[i]=(ll)a[i-1]*inv[i]%mod;
		a[0]=0;
	}
	// a[0..n) <- ln(a) mod x^n, computed as the integral of a'/a. Expects a[0] == 1
	// and n a power of two. a must have room for 2n entries; a[n..2n) is used as scratch.
	inline void Ln(int n, int *a) {
		static int b[N]; clear(n, b+n);
		cpy(n, b, a), Der(n, a), Inv(n, b);
		// The product below is taken at length 2n. Zero the upper half of a here
		// instead of relying on every caller to have done it.
		clear(n, a+n);
		Mul(n<<1, a, b), Inte(n, a);
	}
	// Divide-and-conquer (semi-online) solver for k*f[k] = sum_{i>=1} g[i]*f[k-i],
	// where g already holds i*a[i]. f points at global index l. On entry f[0..n)
	// holds the contributions accumulated from earlier blocks. At a leaf with l > 0
	// the sum is divided by l; the caller seeds f[0] for l == 0.
	void solve_exp(int l, int n, int *f, int *g) {
		if (n==1) {if (l) f[0]=(ll)f[0]*inv[l]%mod; return;}
		int p=n>>1; solve_exp(l, p, f, g);
		// Add the contribution of the finished left half f[0..p) to f[p..n).
		static int h[N]; cpy(p, h, f);
		for (int i=p; i<n; i++) h[i]=0;
		Mul(n, h, g);
		for (int i=p; i<n; i++) inc(f[i], h[i]);
		solve_exp(l+p, p, f+p, g);
	}
	// a[0..n) <- exp(a) mod x^n (expects a[0] == 0, n a power of two).
	// Based on f' = f * a', i.e. k*f[k] = sum_i (i*a[i]) * f[k-i].
	inline void Exp(int n, int *a) {
		for(int i=1; i<n; i++) a[i]=(ll)a[i]*i%mod;
		static int b[N]; clear(n, b);
		b[0]=1, init_inv(n-1), solve_exp(0, n, b, a), cpy(n, a, b);
	}
	// a[0..n) <- sqrt(a) mod x^n by Newton iteration b <- (b^2 + a) / (2b).
	// Expects a[0] == 1 (the iteration starts from b = 1), n a power of two, n <= 2^18.
	inline void Sqrt(int n, int *a) {
		static int b[N], c[N]; b[0]=1;
		for (int len=2, p=1; len<=n; p=len, len<<=1) {
			for (int i=0; i<p; i++) c[i]=add(b[i], b[i]);
			Inv(len, c);
			// Inv already leaves rev[] built for len; init(len) makes that explicit.
			init(len), NTT(len, b, 1), px(len, b, b), NTT(len, b, 0);
			for (int i=0; i<len; i++) inc(b[i], a[i]);
			Mul(len<<1, b, c);
			for (int i=len; i<len+len; i++) b[i]=0;
		}
		cpy(n, a, b), clear(n, b), clear(n, c);
	}
	// Scratch buffers shared by the poly wrappers below.
	int A[N], B[N];
	// Smallest power of two >= m (1 for m <= 1).
	inline int Get(int m) {int n=1; while (n<m) n<<=1; return n;}
	// Value-semantic polynomial: coefficient i is a[i].
	struct poly {
		vector<int> a;
		poly(int n=0) {a.resize(n);}
		inline int size() {return a.size();}
		inline void resize(int n) {a.resize(n);}
		inline void push_back(int x) {a.push_back(x);}
		inline void pop_back() {a.pop_back();}
		int& operator [](int i) {return a[i];}
		// Coefficient-wise sum; the result is as long as the longer operand.
		poly& operator +=(const poly &b) {
			if (a.size()<b.a.size()) a.resize(b.a.size());
			for (size_t i=0; i<b.a.size(); i++) inc(a[i], b.a[i]);
			return *this;
		}
		// Coefficient-wise difference; the result is as long as the longer operand.
		poly& operator -=(const poly &b) {
			if (a.size()<b.a.size()) a.resize(b.a.size());
			for (size_t i=0; i<b.a.size(); i++) dec(a[i], b.a[i]);
			return *this;
		}
		// Product with size() + b.size() - 1 coefficients (empty if that is <= 0).
		poly& operator *=(const poly &b) {
			int len=(int)a.size()+(int)b.a.size()-1;
			// With both operands empty, len is -1. Previously this reached
			// resize(SIZE_MAX) and threw.
			if (len<=0) {a.clear(); return *this;}
			int n=Get(len);
			for (int i=0; i<(int)a.size(); i++) A[i]=a[i];
			for (int i=a.size(); i<n; i++) A[i]=0;
			for (int i=0; i<(int)b.a.size(); i++) B[i]=b.a[i];
			for (int i=b.a.size(); i<n; i++) B[i]=0;
			Mul(n, A, B), a.resize(len);
			for (int i=0; i<len; i++) a[i]=A[i];
			return *this;
		}
		// Multiply by x^k by prepending k zero coefficients (no-op for k <= 0).
		poly& operator <<=(int k) {
			if (k>0) a.insert(a.begin(), (size_t)k, 0);
			return *this;
		}
		poly operator +(const poly &b) {poly tmp=*this; tmp+=b; return tmp;}
		poly operator -(const poly &b) {poly tmp=*this; tmp-=b; return tmp;}
		poly operator *(const poly &b) {poly tmp=*this; tmp*=b; return tmp;}
		poly operator <<(int k) {poly tmp=*this; tmp<<=k; return tmp;}
		// Inverse modulo x^size() (requires a[0] != 0).
		poly inv() {
			int n=Get(a.size());
			for (int i=0; i<(int)a.size(); i++) A[i]=a[i];
			for (int i=a.size(); i<n; i++) A[i]=0;
			Inv(n, A); poly b(a.size());
			for (int i=0; i<b.size(); i++) b[i]=A[i];
			return b;
		}
		// Formal derivative; an empty polynomial yields an empty result.
		poly der() {
			// a.size()-1 would wrap around for an empty polynomial.
			poly b(a.empty()?0:(int)a.size()-1);
			for (int i=1; i<(int)a.size(); i++) b[i-1]=(ll)a[i]*i%mod;
			return b;
		}
		// Formal integral with zero constant term (one coefficient longer).
		poly inte() {
			poly b(a.size()+1); init_inv(a.size());
			for (int i=1; i<b.size(); i++) b[i]=(ll)a[i-1]*Poly::inv[i]%mod;
			return b;
		}
		// ln modulo x^size() (expects a[0] == 1).
		poly ln() {
			int n=Get(a.size());
			for (int i=0; i<(int)a.size(); i++) A[i]=a[i];
			for (int i=a.size(); i<n+n; i++) A[i]=0;
			Ln(n, A); poly b(a.size());
			for (int i=0; i<b.size(); i++) b[i]=A[i];
			return b;
		}
		// exp modulo x^size() (expects a[0] == 0).
		poly exp() {
			int n=Get(a.size());
			for (int i=0; i<(int)a.size(); i++) A[i]=a[i];
			for (int i=a.size(); i<n; i++) A[i]=0;
			Exp(n, A); poly b(a.size());
			for (int i=0; i<b.size(); i++) b[i]=A[i];
			return b;
		}
		// Square root modulo x^size() (expects a[0] == 1).
		poly sqrt() {
			int n=Get(a.size());
			for (int i=0; i<(int)a.size(); i++) A[i]=a[i];
			for (int i=a.size(); i<n; i++) A[i]=0;
			Sqrt(n, A); poly b(a.size());
			for (int i=0; i<b.size(); i++) b[i]=A[i];
			return b;
		}
	};
} using Poly::poly; using Poly::NTT;
// Global sizes and factorial tables; init(n) fills fac/ifac for indices 0..n.
int n, m;
int fac[N], ifac[N];
inline void init(int n) {
	fac[0] = 1;
	for (int i = 1; i <= n; ++i) fac[i] = static_cast<int>(1LL * fac[i - 1] * i % mod);
	// Invert n! once, then walk down using 1/(i-1)! = i * (1/i!).
	ifac[n] = qpow(fac[n], mod - 2);
	for (int i = n; i != 0; --i) ifac[i - 1] = static_cast<int>(1LL * ifac[i] * i % mod);
}
int main() {
#ifdef LOCAL
	// Local debugging only: redirect stdin/stdout to files.
	freopen("1.in", "r", stdin);
	freopen("1.out", "w", stdout);
#endif
	// Build the NTT root table for every transform size up to N - 7 (= 2^19).
	const int max_transform = N - 7;
	Poly::init_g(max_transform);
	return 0;
}
posted @ 2022-07-17 22:31  b1ts  阅读(150)  评论(0)  编辑  收藏  举报