快速数论变换(NTT)+多项式全家桶

在系数均为整数、且答案只需对 998244353 取模的时候,可以用 NTT 代替 FFT:NTT 在模 p = 998244353 = 119·2^23 + 1(原根为 3)的意义下做变换,全程只有整数运算,因此不会出现浮点精度问题。

代码
#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
// Capacity of the bit-reversal table r: the largest transform length M that init()/NTT() accept.
const int N = 20000005;
// NTT prime mod = 119 * 2^23 + 1 (so transform lengths up to 2^23 are valid) and its primitive root g.
const lld g = 3, mod = 998244353;
// Bit-reversal permutation filled by init(). Entries are indices below N, so int is
// enough; a long long table of this size would cost about 160 MB instead of 80 MB.
int r[N];
// Return a^b mod `mod` by binary exponentiation. Expects b >= 0 and a already
// reduced below mod (much larger a can overflow the 64-bit products).
lld powe(lld a, lld b) {
    lld base = 1;
    while(b) {
        if(b & 1) base = base * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return base;
}
// Fill r[0..M) with the bit-reversal permutation for a transform of length M
// (a power of two): r[i] is i with its low floor(log2 M) bits reversed.
void init(int M)
{
    int bits = 0;
    while ((M >> (bits + 1)) > 0) ++bits;   // floor(log2(M)) in integer arithmetic
    r[0] = 0;
    for (int i = 1; i < M; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bits - 1));
}
// In-place iterative NTT of length M over Z/mod. M must be a power of two dividing
// mod - 1; a must hold at least M entries.
// type == 1: forward transform; type == -1: inverse transform including the 1/M factor.
void NTT(vector <lld> &a, int M, int type) {
    init(M);
    // Move every element to its bit-reversed slot; i < r[i] swaps each pair once.
    for (int i = 0; i < M; ++i)
        if (i < r[i]) swap(a[i], a[r[i]]);
    for (int half = 1; half < M; half <<= 1) {
        const int block = half << 1;
        // Principal block-th root of unity (inverted for the inverse transform).
        lld root = powe(g, (mod - 1) / block);
        if (type == -1) root = powe(root, mod - 2);
        for (int start = 0; start < M; start += block) {
            lld tw = 1;
            for (int k = start; k < start + half; ++k) {
                const lld lo = a[k];
                const lld hi = tw * a[k + half] % mod;
                a[k] = (lo + hi) % mod;
                a[k + half] = (lo - hi + mod) % mod;
                tw = tw * root % mod;
            }
        }
    }
    if (type == -1) {
        const lld scale = powe(M, mod - 2);   // 1 / M
        for (int i = 0; i < M; ++i) a[i] = a[i] * scale % mod;
    }
}
// Return the product a * b with coefficients mod `mod`. The result has
// a.size() + b.size() entries; for non-empty inputs that is one more than the
// product's length, so the last entry is 0.
vector<lld> conv(vector<lld> a, vector<lld> b) {
    int siz = a.size() + b.size();
    int M = 1;
    while(M < siz) M <<= 1;   // power-of-two transform length covering the product
    a.resize(M); b.resize(M);
    // NTT() builds its own bit-reversal table, so no separate init(M) is needed.
    NTT(a, M, 1); NTT(b, M, 1);
    // Pointwise product. a and b hold exactly M entries, so the last index is M - 1.
    for(int i = 0; i < M; i++) {
        a[i] = a[i] * b[i] % mod;
    }
    NTT(a, M, -1); a.resize(siz);
    return a;
}
int n, m;
// Reads n, m and the coefficients a[0..n], b[0..m]; prints the n + m + 1
// coefficients of a * b (mod 998244353), space-separated.
int main() {
    // freopen("data.in", "r", stdin);
    // Large inputs: unsync from C stdio and untie cout so cin/cout stay fast.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    cin >> n >> m;
    // One extra zero slot each; conv() pads to a power of two anyway.
    vector <lld> a(n + 2), b(m + 2);
    for(int i = 0; i <= n; i++) {
        cin >> a[i];
    }
    for(int i = 0; i <= m; i++) {
        cin >> b[i];
    }
    vector <lld> c = conv(a, b);
    for(int i = 0; i <= n + m; i++) {
        cout << c[i] << " ";
    }
    return 0;
}

多项式乘法,求逆,开根

P5205

#include <bits/stdc++.h>
using namespace std;
using lld = long long;
const int N = 400010;                 // capacity of the bit-reversal table r
const lld g = 3;                      // primitive root of mod
const lld mod = 998244353;            // NTT prime 119 * 2^23 + 1
lld r[N];                             // bit-reversal permutation, filled by init()

// a^b mod `mod` by square-and-multiply (b >= 0).
lld powe(lld a, lld b) {
    lld acc = 1;
    for (; b; b >>= 1, a = a * a % mod)
        if (b & 1) acc = acc * a % mod;
    return acc;
}
lld inv2;  // 1/2 mod `mod`; main() assigns it before polysqrt() reads it

// Fill r[0..M) with the bit-reversal permutation for a transform of length M
// (a power of two): r[i] is i with its low floor(log2 M) bits reversed.
void init(int M) {
    int bits = 0;
    while ((M >> (bits + 1)) > 0) ++bits;   // floor(log2(M)) in integer arithmetic
    r[0] = 0;
    for (int i = 1; i < M; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bits - 1));
}
// In-place iterative NTT of length M over Z/mod. M must be a power of two dividing
// mod - 1; a must hold at least M entries.
// type == 1: forward transform; type == -1: inverse transform including the 1/M factor.
void NTT(vector <lld> &a, int M, int type) {
    init(M);
    // Move every element to its bit-reversed slot; i < r[i] swaps each pair once.
    for (int i = 0; i < M; ++i)
        if (i < r[i]) swap(a[i], a[r[i]]);
    for (int half = 1; half < M; half <<= 1) {
        const int block = half << 1;
        // Principal block-th root of unity (inverted for the inverse transform).
        lld root = powe(g, (mod - 1) / block);
        if (type == -1) root = powe(root, mod - 2);
        for (int start = 0; start < M; start += block) {
            lld tw = 1;
            for (int k = start; k < start + half; ++k) {
                const lld lo = a[k];
                const lld hi = tw * a[k + half] % mod;
                a[k] = (lo + hi) % mod;
                a[k + half] = (lo - hi + mod) % mod;
                tw = tw * root % mod;
            }
        }
    }
    if (type == -1) {
        const lld scale = powe(M, mod - 2);   // 1 / M
        for (int i = 0; i < M; ++i) a[i] = a[i] * scale % mod;
    }
}
// Return the product a * b with coefficients mod `mod`. The result has
// a.size() + b.size() entries; for non-empty inputs that is one more than the
// product's length, so the last entry is 0.
vector<lld> conv(vector<lld> a, vector<lld> b) { //return a * b
    int siz = a.size() + b.size();
    int M = 1;
    while(M < siz) M <<= 1;   // power-of-two transform length covering the product
    a.resize(M); b.resize(M);
    NTT(a, M, 1); NTT(b, M, 1);
    // Pointwise product. a and b hold exactly M entries, so the last index is M - 1.
    for(int i = 0; i < M; i++) {
        a[i] = a[i] * b[i] % mod;
    }
    // NTT(..., -1) already multiplies by 1/M; scaling again here would divide by M twice.
    NTT(a, M, -1);
    a.resize(siz); return a;
}
// Scratch buffer for polyinv: holds f's first n coefficients zero-padded to the
// transform length, then their NTT. Each recursion level writes it only after its
// own recursive call has returned, so one shared buffer is enough. Its size N
// bounds the transform length polyinv can use.
vector<lld> inv_t(N); 
// Compute h = f^{-1} mod x^n by Newton iteration h <- h * (2 - f * h); each
// recursion level doubles the number of correct coefficients.
//   f: input series; only f[0..n) is read. f[0] must be nonzero mod `mod`
//      (powe(0, mod - 2) would give 0).
//   n: number of coefficients wanted, n >= 1.
//   h: output. Must have at least len entries (len = smallest power of two >= 2n)
//      and be zero in [1, len) on entry: the base case writes only h[0] and each
//      level transforms h[0..len) as is.
// On return h[0..n) holds the inverse and h[n..len) is zero.
void polyinv(const vector<lld> &f, const int n, vector<lld> &h) { //h = f^{-1}
    if(n == 1) {
        // Base case: constant term of the inverse is 1 / f[0] (Fermat's little theorem).
        h[0] = powe(f[0], mod - 2); return ;
    }
    // h <- f^{-1} mod x^ceil(n/2).
    polyinv(f, (n + 1) >> 1, h);  
    // len >= 2n: f * h * h has degree at most 2n - 2, so the cyclic convolution does not wrap.
    int len = 1; while(len < n * 2) len *= 2;
    copy(f.begin(), f.begin() + n, inv_t.begin());
    fill(inv_t.begin() + n, inv_t.begin() + len, 0); 
    // Not needed: NTT() rebuilds the bit-reversal table itself.
    // init(len); 
    NTT(inv_t, len, 1); NTT(h, len, 1);
    // Newton step evaluated pointwise: h * (2 - f * h).
    for(int i = 0; i < len; i++) {
        h[i] = (2ll - inv_t[i] * h[i] % mod + mod) % mod * h[i] % mod;
    }
    NTT(h, len, -1);
    // Keep only the n coefficients that are now correct.
    fill(h.begin() + n, h.begin() + len, 0);
}
// Scratch buffers for polysqrt: tmp receives h^{-1}, tmp2 the first n terms of f.
vector<lld> tmp(N), tmp2(N);
// Compute h = sqrt(f) mod x^n by Newton iteration h <- (h + f / h) / 2.
//   f: input series; only f[0..n) is read. The base case sets h[0] = 1, so this
//      assumes f[0] == 1 (other constant terms are not handled).
//   n: number of coefficients wanted, n >= 1 (n == 0 would recurse forever,
//      since the base case is n == 1).
//   h: output. Must have at least the smallest power of two >= 2n entries and be
//      zero beyond index 0 on entry.
// Reads the global inv2, which must already hold 1/2 mod `mod`.
void polysqrt(const vector<lld> &f, int n, vector<lld> &h) { // h = sqrt(f)
    if(n == 1) {
        h[0] = 1; return ;
    }
    // h <- sqrt(f) mod x^ceil(n/2).
    polysqrt(f, (n + 1) >> 1, h);
    // len >= 2n: h^{-1} * f has degree at most 2n - 2, so nothing wraps around.
    int len = 1;
    while(len < n * 2) {
        len *= 2;
    }
    // tmp <- h^{-1} mod x^n (polyinv needs its output zeroed first).
    fill(tmp.begin(), tmp.begin() + len, 0); polyinv(h, n, tmp); 
    copy(f.begin(), f.begin() + n, tmp2.begin());
    fill(tmp2.begin() + n, tmp2.begin() + len, 0);
    NTT(tmp2, len, 1); NTT(tmp, len, 1); NTT(h, len, 1);
    // Newton step evaluated pointwise: (h + f * h^{-1}) / 2.
    for(int i = 0; i < len; i++) {
        h[i] = (1ll * inv2 * (h[i] + tmp[i] * tmp2[i] % mod) % mod) % mod;
    }
    NTT(h, len, -1);
    // Keep only the n coefficients that are now correct.
    fill(h.begin() + n, h.begin() + len, 0);
}
int n, m;
int main() {
    // freopen("data.in", "r", stdin);
    cin >> n; 
    vector<lld> g(N), ans(N); 
    g.clear(); ans.clear(); inv2 = powe(2, mod - 2);
    for(int i = 0; i < n; i++) {
        cin >> g[i];
    }
    polysqrt(g, n, ans);
    for(int i = 0; i < n; i++) {
        cout << ans[i] << " ";
    }
    return 0;
}

多项式对数,指数,快速幂

P5245

#include <bits/stdc++.h>
using namespace std;
using lld = long long;
const int N = 400010;                 // capacity of the bit-reversal table r
const lld g = 3;                      // primitive root of mod
const lld mod = 998244353;            // NTT prime 119 * 2^23 + 1
lld r[N];                             // bit-reversal permutation, filled by init()

// a^b mod `mod` by square-and-multiply (b >= 0).
lld powe(lld a, lld b) {
    lld acc = 1;
    for (; b; b >>= 1, a = a * a % mod)
        if (b & 1) acc = acc * a % mod;
    return acc;
}
lld inv2;  // NOTE(review): not referenced by the code in this program

// Fill r[0..M) with the bit-reversal permutation for a transform of length M
// (a power of two): r[i] is i with its low floor(log2 M) bits reversed.
void init(int M) {
    int bits = 0;
    while ((M >> (bits + 1)) > 0) ++bits;   // floor(log2(M)) in integer arithmetic
    r[0] = 0;
    for (int i = 1; i < M; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bits - 1));
}
// In-place iterative NTT of length M over Z/mod. M must be a power of two dividing
// mod - 1; a must hold at least M entries.
// type == 1: forward transform; type == -1: inverse transform including the 1/M factor.
void NTT(vector <lld> &a, int M, int type) {
    init(M);
    // Move every element to its bit-reversed slot; i < r[i] swaps each pair once.
    for (int i = 0; i < M; ++i)
        if (i < r[i]) swap(a[i], a[r[i]]);
    for (int half = 1; half < M; half <<= 1) {
        const int block = half << 1;
        // Principal block-th root of unity (inverted for the inverse transform).
        lld root = powe(g, (mod - 1) / block);
        if (type == -1) root = powe(root, mod - 2);
        for (int start = 0; start < M; start += block) {
            lld tw = 1;
            for (int k = start; k < start + half; ++k) {
                const lld lo = a[k];
                const lld hi = tw * a[k + half] % mod;
                a[k] = (lo + hi) % mod;
                a[k + half] = (lo - hi + mod) % mod;
                tw = tw * root % mod;
            }
        }
    }
    if (type == -1) {
        const lld scale = powe(M, mod - 2);   // 1 / M
        for (int i = 0; i < M; ++i) a[i] = a[i] * scale % mod;
    }
}
// Scratch buffer for polyinv: holds f's first n coefficients zero-padded to the
// transform length, then their NTT. Each recursion level writes it only after its
// own recursive call has returned, so one shared buffer is enough. Its size N
// bounds the transform length polyinv can use.
vector<lld> inv_t(N); 
// Compute h = f^{-1} mod x^n by Newton iteration h <- h * (2 - f * h); each
// recursion level doubles the number of correct coefficients.
//   f: input series; only f[0..n) is read. f[0] must be nonzero mod `mod`.
//   n: number of coefficients wanted, n >= 1.
//   h: output. Must have at least len entries (len = smallest power of two >= 2n)
//      and be zero in [1, len) on entry: the base case writes only h[0] and each
//      level transforms h[0..len) as is.
// On return h[0..n) holds the inverse and h[n..len) is zero.
void polyinv(const vector<lld> &f, const int n, vector<lld> &h) { //h = f^{-1}
    if(n == 1) {
        // Base case: constant term of the inverse is 1 / f[0] (Fermat's little theorem).
        h[0] = powe(f[0], mod - 2); return ;
    }
    // h <- f^{-1} mod x^ceil(n/2).
    polyinv(f, (n + 1) >> 1, h);  
    // len >= 2n: f * h * h has degree at most 2n - 2, so the cyclic convolution does not wrap.
    int len = 1; while(len < n * 2) len *= 2;
    copy(f.begin(), f.begin() + n, inv_t.begin());
    fill(inv_t.begin() + n, inv_t.begin() + len, 0); 
    NTT(inv_t, len, 1); NTT(h, len, 1);
    // Newton step evaluated pointwise: h * (2 - f * h).
    for(int i = 0; i < len; i++) {
        h[i] = (2ll - inv_t[i] * h[i] % mod + mod) % mod * h[i] % mod;
    }
    NTT(h, len, -1);
    // Keep only the n coefficients that are now correct.
    fill(h.begin() + n, h.begin() + len, 0);
}
// Scratch buffers for polyln: inv_f receives f^{-1}, rf the derivative f'.
vector<lld> inv_f(N), rf(N);
// Compute g = ln(f) = integral(f' / f), with g[0..n) the result mod x^n.
//   f: input series, assumed to have f[0] == 1 (ln is only defined there).
//      The derivative loop reads f[0..n+1], so f needs at least n + 2 elements.
//   g: output (this parameter shadows the global primitive root g inside the
//      function). Must have at least `limit` entries; g[0..n] are written
//      (g[n] in addition to the n requested terms) and g[n+1..limit) are zeroed.
void polyln(const vector<lld> &f, const int n, vector<lld> &g) {
    // limit >= 2n so f' * f^{-1} does not wrap; l is computed but unused.
    int limit = 1, l = 0;
    while(limit < n * 2) limit <<= 1, l++;
    for(int i = 0; i < limit; i++) {
        inv_f[i] = rf[i] = 0; 
    }
    polyinv(f, n, inv_f); 
    // rf = f' (coefficient i is (i + 1) * f[i + 1]).
    for(int i = 0; i <= n; i++) {
        rf[i] = 1ll * f[i + 1] * (i + 1) % mod;
    }
    NTT(rf, limit, 1); NTT(inv_f, limit, 1);
    for(int i = 0; i < limit; i++) {
        g[i] = 1ll * rf[i] * inv_f[i] % mod;
    }
    NTT(g, limit, -1);
    fill(g.begin() + n + 1, g.begin() + limit, 0);
    // Integrate in place, highest index first so g[i - 1] is still the unintegrated value.
    for(int i = n; i >= 1; i--) {
        g[i] = g[i - 1] * powe(i, mod - 2) % mod;
    }
    g[0] = 0;
}
// Scratch buffers for polyexp: ln_f receives ln(b), res the Newton factor.
// Each recursion level uses them only after its recursive call has returned.
vector<lld> ln_f(N), res(N);
// Compute b = exp(a) mod x^len by Newton iteration b <- b * (1 - ln b + a).
//   a: input series; reads a[0..len]. The base case sets b[0] = 1, so this
//      assumes a[0] == 0.
//   b: output; expected zero beyond index 0 on entry. polyln(b, len, ...) reads
//      b[0..len+1], so b needs at least len + 2 entries (and at least `limit`).
// NOTE(review): b is not truncated after the final inverse NTT, so entries at
// index >= len keep leftover product terms. The Newton step only relies on b
// being correct mod x^ceil(len/2), but callers should read only b[0..len).
void polyexp(const vector<lld> &a, const int len, vector<lld> &b) {
    if(len == 1) {
        b[0] = 1; return ;
    }
    // b <- exp(a) mod x^ceil(len/2), then ln_f <- ln(b).
    polyexp(a, (len + 1) >> 1, b); polyln(b, len, ln_f);
    // limit >= 2 * len for the product b * res; l is computed but unused.
    int limit = 1, l = 0;
    while(limit < (len * 2)) limit <<= 1, l++;
    fill(res.begin(), res.begin() + limit, 0);
    // res = a - ln(b) + 1 (polyln leaves ln_f[0] == 0, so the +1 goes to res[0]).
    for(int i = 0; i <= len; i++) {
        res[i] = (a[i] - ln_f[i] + mod) % mod;
    }
    res[0]++;
    NTT(res, limit, 1); NTT(b, limit, 1);
    for(int i = 0; i < limit; i++) {
        b[i] = b[i] * res[i] % mod;
    }
    NTT(b, limit, -1);
}
void polypow(const vector<lld> &f, const int n, lld k, vector<lld> &h) {
    vector<lld> tmp(N);
    polyln(f, n, tmp);
    for(int i = 0; i < n; i++) {
        tmp[i] = tmp[i] * k % mod;
    }
    polyexp(tmp, n, h);
}
int n; lld k = 0;
string kk;
// Reads n, the exponent k as a decimal string (presumably because it can exceed
// 64-bit range) and f[0..n); prints f^k mod x^n, space-separated.
int main() {
    // freopen("data.in", "r", stdin);    
    cin >> n >> kk;
    // Fold k's digits while keeping it small. polypow only multiplies by k
    // modulo `mod`, so only k mod `mod` matters.
    for (char digit : kk) {
        k = k * 10 + (digit - '0');
        if (k > mod) k %= mod;
    }
    vector<lld> a(N);
    for (int i = 0; i < n; ++i) cin >> a[i];
    vector<lld> ans(N);
    polypow(a, n, k, ans);
    for (int i = 0; i < n; ++i) cout << ans[i] << " ";
    return 0;
}

封装版 NTT(把多项式运算封装为 Poly 结构体)

#include <bits/stdc++.h>
using namespace std;
using lld = long long;
const int N = 4e6 + 10;                   // capacity of the W / Cw / rev tables
const lld mod = 998244353, g = 3;         // NTT prime 119 * 2^23 + 1 and its primitive root

// a^b mod `mod` by square-and-multiply (b >= 0).
lld powe(lld a, lld b) {
    lld acc = 1;
    for (; b; b >>= 1, a = a * a % mod)
        if (b & 1) acc = acc * a % mod;
    return acc;
}

// Modular inverse via Fermat's little theorem (mod is prime); Inv(0) yields 0.
inline lld Inv(lld x) {
    return powe(x, mod - 2);
}
lld W[N], Cw[N], rev[N];
int deg, lg;
void Init(int len) {
    deg = 1; lg = 0; while(deg < len) deg <<= 1, ++lg;
    W[0] = Cw[0] = 1; 
    lld Wp = powe(g, (mod - 1) / deg), Cp = powe(Inv(g), (mod - 1) / deg);
    for(int i = 1; i < deg; i++) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
        W[i] = W[i - 1] * Wp % mod; Cw[i] = Cw[i - 1] * Cp % mod;
    }
}

struct Poly {
    vector<lld> f;
    lld &operator[] (const int &i) {return f[i];}
    lld operator[] (const int &i)const {return f[i];}
    int Size() {return f.size();}
    int Size()const {return f.size();}
    void Resize(int len) {f.resize(len);}
    void PopZero() {while(f.size() > 1 && f.back() == 0) f.pop_back();}
    void Reverse() {reverse(f.begin(), f.end());}
    void NTT(int deg, lld ww[], int opt) {
        Resize(deg);
        for(int i = 0; i < deg; i++) {
            if(i < rev[i]) swap(f[i], f[rev[i]]);
        }
        for(int t = deg >> 1, m = 1; m < deg; m <<= 1, t >>= 1) {
            for(int i = 0; i < deg; i += (m << 1)) {
                for(int j = 0; j < m; j++) {
                    int x = f[i + j], y = ww[t * j] * f[i + j + m] % mod;
                    f[i + j] = (x + y) % mod, f[i + j + m] = (x - y + mod) % mod;
                }
            }
        }
        if(opt == -1) {
            int inv = Inv(deg);
            for(int i = 0; i < deg; i++) {f[i] = f[i] * inv % mod;}
        }
    }
	friend Poly operator + (const Poly &x, const Poly &y) {
		Poly res; res.Resize(max(x.Size(), y.Size()));
		for(int i = 0; i < x.Size(); i++) {
            res[i] = x[i];
        }
		for(int i = 0; i < y.Size(); i++) {
            res[i] = (res[i] + y[i]) % mod;
        }
		return res;
	}
	friend Poly operator - (const Poly &x, const Poly &y) {
		Poly res; res.Resize(max(x.Size(), y.Size()));
		for(int i = 0; i < x.Size(); i++) {
            res[i] = x[i];
        }
		for(int i = 0; i < y.Size(); i++) {
            res[i] = (res[i] - y[i] + mod) % mod;
        }
		return res;
	}
	friend Poly operator * (const Poly &x, const Poly &y) {
		Poly res, A = x, B = y; 
        int sA = A.Size(), sB = B.Size(); Init(sA + sB - 1);
		A.NTT(deg, W, 1); B.NTT(deg, W, 1); res.Resize(deg);
		for(int i = 0; i < deg; i++) {
            res[i] = A[i] * B[i] % mod;
        }
		res.NTT(deg, Cw, -1); res.Resize(sA + sB - 1);
		return res;
    }
	Poly pInv() {
		Poly res; res.Resize(1);
		if(f.empty()) return res;
		int len, lim; res[0] = Inv(f[0]); Poly A, B;
		for(len = 2; len < (Size() << 1); len <<= 1) {
			A.f.assign((this->f).begin(), (this->f).begin() + min(len, Size()));
            B = res; B.Resize(min(len, Size()));
			Init(A.Size() + B.Size() - 1); 
            A.NTT(deg, W, 1), B.NTT(deg, W, 1); res.Resize(deg);
			for(int i = 0; i < deg; i++) {
                res[i] = ((2ll - A[i] * B[i] % mod) * B[i] % mod + mod) % mod;
            }
			res.NTT(deg, Cw, -1); res.Resize(len);
		}
		res.Resize(Size()); return res;
	}
	friend Poly operator / (const Poly &x, const Poly &y) {
		Poly E = x, D = y;
		E.Reverse(), D.Reverse(); D.Resize(x.Size() - y.Size() + 1);
		D = D.pInv(); D = D * E;
		D.Resize(x.Size() - y.Size() + 1); D.Reverse();
		return D;
	}
	Poly Dir() {
        Poly res; res.Resize(Size());
        for(int i = 0; i + 1 < f.size(); i++) {
            res[i] = f[i + 1] * (i + 1) % mod; 
        }
        return res;
    }
    void Integ() {
        for(int i = ((int)f.size()) - 1; i >= 1; i--) {
            f[i] = f[i-1] * Inv(i) % mod;
        }
        f[0]=0;
    }
	Poly Ln() {
		Poly res = Dir(); Poly I = pInv();
		res = res * I; res.Integ(); res.Resize(Size());
		return res;
	}
	Poly Exp() {
		Poly res; res.Resize(1); res[0] = 1;
		int len; Poly A, B;
		for(len = 2; len < (Size() << 1); len <<= 1) {
			res.Resize(len); A = res.Ln();
			B.f.assign((this->f).begin(), (this->f).begin() + min(len, Size()));
			B = B - A; B[0]++; res = res * B;
			res.Resize(len);
		}
		res.Resize(Size()); return res;
	}
	Poly Powe(int k) { 
		Poly res = Ln();
		for(int i = 0; i < res.Size(); i++) {
            res[i] = res[i] * k % mod;
        }
		return res.Exp();
	}
	Poly Sqrt() {
		Poly res; res.Resize(1); res[0] = 1;
		int len; Poly A, B; int inv2 = Inv(2);
		for(len = 2; len < (Size() << 1); len <<= 1){
			res.Resize(len); A = res.pInv();
			B.f.assign((this->f).begin(), (this->f).begin() + min(len,Size()));
			A = A * B;
			for(int i = 0; i < len; i++) {
                res[i] = (res[i] + A[i]) * inv2 % mod;
            }
			res.Resize(len);
		}
		res.Resize(Size()); return res;
	}
};

// Placeholder entry point: nothing is computed here; drive the Poly library from main.
int main() {}
posted @ 2023-09-16 18:27  Mcggvc  阅读(55)  评论(0编辑  收藏  举报