洛谷P4191 [CTSC2010]性能优化 题解
题意
给出 \(n\) 次多项式 \(A,B\) 和常数 \(C\),求 \(A\times B^C\) 的系数对 \(n+1\) 取模的结果,其中乘法定义为模 \(n\) 意义下的循环卷积。保证 \(n\) 能被分为若干不超过 \(10\) 的正整数的乘积且 \(n+1\) 是质数。(\(1\le n\le 5\times 10^5,1\le C\le 10^9,\rm 6s,250MB\))
题解
闲话。我们模拟赛考了一道循环卷积的题,我发现我啥也不会。学习了学长的博客之后做了这道题。然后发现我之前对 FFT 的理解太浅了。
首先要知道一点,\(\rm DFT\) 和 \(\rm IDFT\) 的过程实际上是在做循环卷积,循环的长度和 \(\rm DFT\) 时用到的单位根次数相同。考虑:
即 \(f,g\) 的 \(n\) 次循环卷积。用单位根反演搞掉取模:
注意到后面两个求和式分别相当于 \(F(\omega_{n}^d),G(\omega_n^d)\)。注意到求 \(F(\omega_{n}^d)\times G(\omega_n^d)\) 的过程就是 \(\rm DFT\) 转化成点值后点积,得到 \(\operatorname{DFT}(H)\)。之后 \(\operatorname{IDFT}\) 的过程只需要代入 \(\omega_{n}^{-d}\),再除 \(n\) 就好了。
再看回上述过程的实现之一,\(\rm FFT\)。它依赖的是 \(n\) 为 \(2\) 的次幂,且 \(\omega_{\frac{n}{2^k}}\) 存在(如果在模意义下的话),来对多项式进行分治。具体来讲,考虑将原多项式奇偶分治。
从而:
这样分治还不太够,因为还要求代入的单位根次数和多项式次数相同。所以我们再把单位根变一下。
发现 \(k\) 和 \(k+\frac{n}{2}\) 处的值对应的表达式很相似。
从而我们可以将问题规模缩小一半,且可以一次求两个值。而对于 \(\rm IDFT\),相当于代入的单位根取反,最后得到的系数再除 \(n\)。
所以只需要把 \(F\) 的系数从 \(1\) 到 \(n-1\) 翻转一下就能转化成 \(\rm DFT\) 的过程。
平常我们体感 \(\rm FFT\) 只是普通的多项式乘法是因为,我们单位根用的是 \(>\deg F+\deg G\) 的最小的 \(2\) 的次幂。模数比最大可能得到的次数还大,就没有影响了。
回到这道题。这道题实际上要求的是模 \(n+1\) 意义下 \(n\) 次单位根下的 \(\rm DFT\)。我们依然想沿用 \(\rm FFT\) 的分治思路,最大的问题在于不能再每次方便地分治成左右两半了。但题目保证,\(n\) 的所有因子都很小,所以我们考虑每次分治都用 \(n\) 的某个质因子 \(d\) 分治,然后花费 \(\mathcal{O}(d)\) 的时间来把它们组合起来。从而做到 \(\mathcal{O}(n\sum k_ip_i)\) 的时间复杂度。\(k_i,p_i\) 分别表示 \(n\) 的质因子和它的次数。考虑每个质因子会出现 \(k_i\) 次,每次都会造成 \(\mathcal{O}(p_in)\) 的时间复杂度。(这里和 \(\rm FFT\) 不一样,下面说)
具体来讲,我们每层选择 \(n\) 的一个因子 \(m\),因为模数的 \(\varphi\) 值是 \(n\),从而求出原根 \(g\) 后 \(m\) 次单位根 \(g^{\frac{n}{m}}\) 一定存在。然后仿照上述过程分治:
之后组合:
这里就不能再像刚刚一样特化一下 \(k\) 比较大的情况了,不过只需要让指数对单位根的次数取模即可得到需要的值。因为不能一次求出多个值,所以每层需要的计算量之和都是 \(\mathcal{O}(nd)\) 的。
似乎还可以实现成非递归版的,但我不会了。不过我这个递归版的开了 O2 好像还挺快。
#include <cstdio>
#include <vector>
#include <algorithm>
struct IO
{
static const int N = 1 << 22;
char buf[N], pbuf[N], *p1 = buf, *p2 = buf, *pp = pbuf;
#define gc() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, N, stdin), p1 == p2) ? EOF : *p1++)
template <typename T>
void read(T& x)
{
x = 0; char ch; int f = 0;
while ((ch = gc()) < '0' || ch > '9') f |= (ch == '-');
while (x = (x << 1) + (x << 3) + (ch ^ 48), (ch = gc()) >= '0' && ch <= '9') ;
if (f) x = ~x + 1;
}
void putc(char c)
{
if (pp - pbuf == N) fwrite(pbuf, 1, N, stdout), pp = pbuf;
*pp++ = c;
}
template <typename T>
void print(T x)
{
static int st[20]; int tp = 0;
if (x < 0) putc('-'), x = ~x + 1;
do st[++tp] = x % 10, x /= 10; while (x);
while (tp) putc(st[tp--] + '0');
}
~IO() { fwrite(pbuf, pp - pbuf, 1, stdout); }
}io;
const int N = 5e5 + 10; typedef long long ll;
int d[30], gn[N], n, mod, m;
inline int ksm(int a, int b)
{
int ret = 1;
while (b)
{
if (b & 1) ret = (ll)ret * a % mod;
a = (ll)a * a % mod; b >>= 1;
}
return ret;
}
void init()
{
std::vector<int> P;
for (int i = 2, t = n; i <= 7; ++i)
{
if (t % i) continue;
P.push_back(i);
while (t % i == 0) d[++m] = i, t /= i;
}
auto check = [&](int g)
{
int phi = mod - 1;
for (auto p : P) if (ksm(g, phi / p) == 1) return false;
return true;
};
int G = 1; while (!check(G)) ++G;
gn[0] = 1; for (int i = 1; i < N; ++i) gn[i] = (ll)gn[i - 1] * G % mod;
}
struct Poly
{
std::vector<int> a;
int& operator[](const int& id) { return a[id]; }
void setTime(const int& tim) { a.resize(tim + 1); }
int getTime() { return (int)a.size() - 1; }
std::vector<int>::iterator begin() { return a.begin(); }
std::vector<int>::iterator end() { return a.end(); }
}A, B;
void FFT(Poly& F, int dep)
{
if (!F.getTime()) return ;
int n = F.getTime() + 1, m = d[dep];
std::vector<Poly> A; A.resize(m);
for (int i = 0; i < m; ++i) A[i].setTime(n / m - 1);
for (int i = 0; i < n; ++i) A[i % m][(i - i % m) / m] = F[i];
for (int i = 0; i < m; ++i) FFT(A[i], dep + 1);
for (int i = 0, p = ::n / n; i < n; ++i)
{
F[i] = 0;
for (int j = 0, q = 0; j < m; ++j, (q += i) %= n)
(F[i] += (ll)gn[p * q] * A[j][i % (n / m)] % mod) %= mod;
}
}
int main()
{
int C; io.read(n); io.read(C); mod = n + 1; init();
A.setTime(n - 1); B.setTime(n - 1);
for (int i = 0; i < n; ++i) io.read(A[i]);
for (int i = 0; i < n; ++i) io.read(B[i]);
FFT(A, 1); FFT(B, 1);
for (int i = 0; i < n; ++i) A[i] = (ll)A[i] * ksm(B[i], C) % mod;
std::reverse(++A.begin(), A.end()); FFT(A, 1);
for (int i = 0, inv = ksm(n, mod - 2); i < n; ++i)
A[i] = (ll)A[i] * inv % mod, io.print(A[i]), io.putc('\n');
return 0;
}