O(nlog^2/loglogn)的cdq FFT
论文鸽在群里说了一下这个东西,我也实现了一下,发现效果还不错。
由于这个 exp 的 \(O(n\log n)\) 算法非常的慢,所以我们一般采用 \(O(n\log^2 n)\) 的分治 FFT 来求解。
普通的分治 FFT 已经可以与论文鸽的 \(O(n\log n)\) exp 五五开了,但是有没有更快的方法呢?
注意,这个优化只能在 cdq FFT 的时候采用,也就是说不能优化 n 个一次多项式的卷积之类的问题。
\(O(n\log^2n)\)
我们先来回忆一下普通的分治 FFT 是如何做的。
我们假设 \(F(x) = e^{G(X)}\),这里我们知道 \(G(x)\),我们要求解 \(F(x)\)。
两侧求导,得 \(F^{'}(x) = e^{G(X)} \times G^{'}(X) = F(X) \times G^{'}(X)\)
也就是说我们是对这个式子进行求解:\(F^{'}(x) = F(X) \times G^{'}(X)\)
我们采用 \(solve(l, r)\) 表示求解 \(F(X)\) 的第 \(l\) 项到第 \(r\) 项。
取区间中点 \(mid\)。
先调用 \(solve(l, mid)\) 来求解出前半部分。
再计算左侧对右侧的贡献。
再调用 \(solve(mid + 1, r)\) 来求解出后半部分。
这就是普通的分治 FFT。
\(O(\frac{n\log^2n}{\log \log n})\)
首先分治 FFT 是一个树状结构,我们往往可以尝试增加一层往下的分支数来优化深度。
我们设分支数为 \(B\)。
如果直接分治 FFT,那么需要计算每个儿子对后面儿子的贡献,每次计算需要一个长度为 \(O(n / B)\) 的卷积(\(n\)为目前分治区间长度)。
也就是说时间复杂度 \(T(n) = B \times T(\frac{n}{B}) + B \times n \times \log {\frac{n}{b}}\),大力求解得 \(B=2\) 时最优。
我是不是在玩你。
好的我们继续。
我们真的是每一对儿子都要用一次卷积来计算贡献吗?
我们可以先求出这个儿子的点值,考虑它对后面儿子的贡献,这个儿子的点值就不用重复计算了。
存储前面儿子对它的贡献的时候,你也可以直接存储点值,最后一次 FFT 转换即可,也不用多次计算了。
而卷上的是 \(G^{'}(x)\) 的一个区间,这个区间的点值也可以提前计算。
也就是说我们只需要 \(O(B)\) 次长度为 \(O(n / B)\) 的 FFT 了!
因此计算贡献的部分复杂度变为了 \(O(B^2 \times \frac{n}{B} + B \times \frac{n}{b} \times \log {\frac{n}{b}})\),
即 \(O(Bn + n \times \log {\frac{n}{b}})\)。
时间复杂度 \(T(n) = B \times T(\frac{n}{B}) + Bn + n \times \log {\frac{n}{b}}\)。不错,求解一下。
发现 \(B = O(\log n)\) 的时候最优秀,时间复杂度为\(O(\frac{n\log^2n}{\log \log n})\)。
事实上由于计算贡献的时候非常的****,可以使用 avx2 进行优化,亲测一定的常数优化之后进行 \(4 \times 10^6\) 的 exp 只需要 1.5s。
当然其他类似的 cdq FFT 也可以这样进行优化,祝大家早日吊打 \(O(n\log n)\)。
贴代码:
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,tune=native")
#include<bits/stdc++.h>
#define rep(i, l, r) for(int i = (l), i##end = (r);i <= i##end;++i)
const int maxn = 1 << 19 | 1;
typedef long long ll;
typedef unsigned long long u64, ull;
const int mod = 998244353;
struct istream {
static const int size = 1 << 21;
char buf[size], *vin;
inline istream() {
fread(buf,1,size,stdin);
vin = buf - 1;
}
inline istream& operator >> (int & x) {
for(x = *++vin & 15;isdigit(*++vin);) x = x * 10 + (*vin & 15);
return * this;
}
} cin;
struct ostream {
static const int size = 1 << 21;
char buf[size], *vout;
unsigned map[10000];
inline ostream() {
for(int i = 0;i < 10000;++i) {
int p = i;
map[i] = p % 10 + 48, p /= 10;
map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
map[i] = map[i] << 8 | p % 10 + 48, p /= 10;
}
vout = buf + size;
}
inline ~ ostream()
{ fwrite(vout,1,buf + size - vout,stdout); }
inline ostream& operator << (int x) {
for(;x > 10000;x /= 10000) *--(unsigned*&)vout = map[x % 10000];
do *--vout = x % 10 + 48; while(x /= 10);
return * this;
}
inline ostream& operator << (char x) {
*--vout = x;
return * this;
}
} cout;
inline ll pow(ll a,int b,ll ans = 1){ for(;b;b >>= 1, a = a * a % mod) if(b & 1) ans = ans * a % mod; return ans; }
inline ll inverse(ll x){ return pow(x, mod - 2); }
int wn[1 << 13], rev[1 << 14], inv[maxn], lim, invlim;
inline void init_(int n) {
int N = 1; for(;N < n;) N <<= 1;
for(int i = 1;i < N;i <<= 1) {
const int w = pow(3, mod / i / 2); wn[i] = 1;
for(int j = 1;j < i;++j) wn[i + j] = (ll) wn[i + j - 1] * w % mod;
}
for(int i = 1;i <= N;i <<= 1) {
for(int j = 1;j < i;++j) rev[i + j] = rev[i + (j >> 1)] >> 1 | j % 2 * i / 2;
}
}
inline void init(int len) {
lim = len; invlim = mod - (mod - 1) / lim;
}
inline void reduce(int & x) {
x += x >> 31 & mod;
}
static u64 t[1 << 13];
inline void fft(int * a,int type) {
for(int i = 0;i < lim;++i) t[i] = a[rev[i + lim]];
#define trans(i, j, k) \
{ \
const u64 x = wn[i + k] * t[i + j + k] % mod; \
t[i + j + k] = t[j + k] + mod - x, t[j + k] += x; \
}
for(int i = 1;i < lim;i <<= 1) {
if(i == 1) {
for(int j = 0;j < lim;j += 8) {
trans(1, j, 0);
trans(1, j + 2, 0);
trans(1, j + 4, 0);
trans(1, j + 6, 0);
}
} else if(i == 2) {
for(int j = 0;j < lim;j += 8) {
trans(2, j, 0);
trans(2, j, 1);
trans(2, j + 4, 0);
trans(2, j + 4, 1);
}
} else {
for(int j = 0;j < lim;j += i + i) for(int k = 0;k < i;k += 4) {
trans(i, j, k + 0);
trans(i, j, k + 1);
trans(i, j, k + 2);
trans(i, j, k + 3);
}
}
}
if(type == 1) {
for(int i = 0;i < lim;++i) a[i] = t[i] % mod;
}
if(type == 0) {
a[0] = t[0] * invlim % mod;
for(int i = 1;i < lim;++i) a[i] = t[lim - i] * invlim % mod;
}
}
inline void fill(int * a, const int * b, int len) {
memcpy(a, b, len << 2), memset(a + len, 0, lim - len << 2);
}
typedef std::function<int(int, int*)> fc;
struct solver {
static const int C = 128;
static const int B = 64;
int n, N;
int rem[maxn], g[maxn], * MM;
int M[B][(maxn << 1) / B];
u64 g0[maxn << 2];
inline void Init(int len, int * multi) {
MM = multi;
for(n = len, N = 1;N < len;N <<= 1);
for(int mid = (N + N) / B;mid > 1;mid /= B) {
init(mid);
for(int j = 0;j + 1 < B;++j) {
if(j * mid / 2 < n) {
for(int i = 0;i < mid;++i) M[j][mid + i] = MM[i + j * mid / 2];
fft(M[j] + mid, 1);
}
}
}
}
inline void solve(int l, int r, u64 * g0, const fc & xxx) {
if(r - l < C) {
for(int i = l;i < r;++i) {
int j = l;
u64 x = rem[i];
#define T(o) (u64) g[j + o] * MM[i - j - o]
for(;j + 15 < i;j += 16) {
x = (x + T(0) + T(1) + T(2) + T(3) + T(4) + T(5) + T(6) + T(7) +
T(8) + T(9) + T(10) + T(11) + T(12) + T(13) + T(14) + T(15)) % mod;
}
if(j + 7 < i) x += T(0) + T(1) + T(2) + T(3) + T(4) + T(5) + T(6) + T(7), j += 8;
if(j + 3 < i) x += T(0) + T(1) + T(2) + T(3), j += 4;
if(j + 1 < i) x += T(0) + T(1), j += 2;
if(j < i) x += T(0);
#undef T
rem[i] = x % mod;
g[i] = xxx(i, rem + i);
}
return ;
}
const int DT = (r - l) / B;
if(l) memset(g0, 0, r - l << 4);
int end = 0;
for(;end < B && l + end * DT < n;++end);
for(int i = 0;i < end;++i) {
int L = l + i * DT, R = L + DT;
if(i) {
static int T[maxn];
init(DT + DT);
for(int j = 0;j < lim;++j) T[j] = g0[2 * i * DT + j] % mod;
fft(T, 2);
for(int j = L;j < R;++j) rem[j] = (rem[j] + (ll) invlim * t[lim - j + L - DT]) % mod;
}
solve(L, R, g0 + (r - l << 1), xxx);
if(i != end - 1) {
init(DT + DT);
static int b[maxn];
fill(b, g + L, R - L), fft(b, 1);
for(int j = i + 1;j < end;++j) {
ull * g1 = g0 + lim * j;
if(i == B / 2) {
for(int k = 0;k < lim;++k) {
g1[k] = (g1[k] + (ll) b[k] * M[j - i - 1][lim + k]) % mod;
}
} else {
for(int k = 0;k < lim;++k) {
g1[k] += (ll) b[k] * M[j - i - 1][lim + k];
}
}
}
}
}
}
inline void solve(fc x) { solve(0, N, g0, x); }
};
int n, a[maxn], b[maxn];
int main() {
static solver ln, exp;
cin >> n;
for(int i = 0;i < n;++i) {
cin >> a[i]; if(i) a[i] = mod - a[i];
b[i] = (ll) a[i] * i % mod;
}
inv[1] = 1;
for(int i = 2;i < n;++i) {
inv[i] = ll(mod - mod / i) * inv[mod % i] % mod;
}
init_((n + n) / solver::B + 1);
ln.Init(n, a);
ln.solve([](int pos, int * now) { return reduce(*now -= b[pos + 1]), *now; });
for(int i = 1;i < n;++i) {
b[i] = (ll) ln.g[i - 1] * inv[2] % mod;
}
exp.Init(n, b);
exp.solve([](int pos, int * now) { return int(pos == 0 ? 1 : (ll) *now * inv[pos] % mod); });
for(int i = n - 1;i >= 0;--i) {
cout << ' ' << exp.g[i];
}
}