洛谷P4191 [CTSC2010]性能优化(再探FFT)
题目大意
给定长度为 \(n\) 的序列 \(A, B\) 和整数 \(C\),求 \(A * B^{C}\),其中 \(*\) 表示长度为 \(n\) 的循环卷积,所有系数对 \(n+1\) 取模。保证 \(n\) 的所有质因子都小于 10,且 \(n+1\) 是质数。
解题思路
做这题要更深入的理解 FFT
其实我们平常做的 FFT 计算的本来就是循环卷积(长度为 \(2^k\),即在模 \(x^{2^k}-1\) 意义下相乘),只不过我们故意把 \(2^k\) 取得比乘积的项数大,所以看不到循环(回绕)的效果。
一般来讲我们有
\[(f * g)_r = \sum_{i,j}\left[(i + j) \bmod k = r\right] f_i\, g_j, \qquad r = 0, 1, \dots, k-1
\]
其中 \(k\) 是所用单位根的阶,也就是循环卷积的长度。
那么在这题里直接取 \(k = n\),用 \(n\) 次单位根做长度为 \(n\) 的循环卷积即可。由于模数 \(P = n + 1\) 是质数,\(P\) 的原根 \(g\) 的阶恰好是 \(n\),所以可以直接取 \(w_n = g\)。
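具体来说,先把每个 \(w_n^i\)(\(0 \le i < n\))代入求出点值,对点值逐点做快速幂,最后代入逆单位根 \(w_n^{-i}\) 做逆变换并除以 \(n\) 即可。注意在模 \(n+1\) 意义下 \(n \equiv -1\),所以 \(n^{-1} \equiv n\),代码最后乘的就是 \(n\)。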
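求点值时要做混合基的分治:每一层按当前的质因子 \(k\)(可能是 2、3、5、7)把系数按下标模 \(k\) 分成 \(k\) 组,而不只是像普通 FFT 那样按奇偶分成 2 组。记 \(A_j(x) = \sum_{m} a_{mk+j}\, x^m\),则
\[A(w_n^i) = \sum_{j=0}^{k-1} w_n^{ij}\, A_j\!\left(w_{n/k}^{i}\right)
\]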
代码
#include <queue>
#include <vector>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MP make_pair
#define ll long long
#define fi first
#define se second
using namespace std;
/**
 * Read one integer from stdin into x.
 *
 * Characters before the first digit are skipped; a '-' among them makes the
 * result negative. Reading stops at the first non-digit after the number.
 * If end of input is reached before any digit, x is left as 0 (instead of
 * spinning forever on EOF).
 *
 * @param x  destination; overwritten
 */
template <typename T>
void read(T &x) {
    x = 0;
    bool f = 0;
    // getchar() returns int: keeping it as int lets EOF be distinguished from
    // a 0xFF byte, and keeps isdigit() away from negative non-EOF values (UB).
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar())
        if (c == '-') f = 1;
    for (; isdigit(c); c = getchar())  // isdigit(EOF) is false
        x = x * 10 + (c ^ 48);
    if (f) x = -x;
}
/**
 * Write integer x to stdout in decimal, followed by the character ed.
 *
 * Digits are taken from the remainders of x directly (x itself is never
 * negated), so the most negative value of a signed type, whose negation
 * would overflow, is printed correctly.
 *
 * @param x   value to print (any integer type)
 * @param ed  trailing character, newline by default
 */
template<typename F>
inline void write(F x, char ed = '\n')
{
    static short st[40];  // enough for 128-bit integers (39 digits)
    short tp = 0;
    const bool neg = x < 0;
    if (neg) putchar('-');
    // Since C++11 division truncates toward zero: for negative x, x % 10 lies
    // in [-9, 0], so its negation is the digit.
    do {
        F r = x % 10;
        st[++tp] = static_cast<short>(neg ? -r : r);
        x /= 10;
    } while (x);
    while (tp) putchar('0' | st[tp--]);
    putchar(ed);
}
/// Chmax helper: raise x to y when y is larger.
template <typename T>
inline void Mx(T &x, T y) {
    if (x < y) x = y;
}
/// Chmin helper: lower x to y when y is smaller.
template <typename T>
inline void Mn(T &x, T y) {
    if (x > y) x = y;
}
// Modulus for all arithmetic; main sets it to n + 1.
int P;
/**
 * Modular exponentiation: x^mi mod P by binary exponentiation.
 *
 * The base is first reduced into [0, P). Without that reduction a negative
 * base gives a negative result, and x * x overflows once |x| exceeds ~3e9.
 *
 * @param x   base (any value)
 * @param mi  exponent; expected >= 0 (a negative exponent yields 1 % P
 *            instead of looping forever)
 * @return    x^mi mod P, in [0, P)  (assumes P >= 1)
 */
long long fpw(long long x, long long mi) {
    x %= P;
    if (x < 0) x += P;
    long long res = 1 % P;  // 1 % P keeps the result correct when P == 1
    for (; mi > 0; mi >>= 1, x = x * x % P)
        if (mi & 1) res = res * x % P;
    return res;
}
// Capacity of the coefficient arrays; n must be at most N.
const int N = 600500;
// Candidate prime factors of n, 1-indexed (prime[0] is a placeholder).
int prime[] = {0, 2, 3, 5, 7};
// st[1..cnt]: prime factors of n, in the order NTT uses them as radices.
int st[N];
int tot;  // NOTE(review): not referenced in the visible code
int G;    // primitive root modulo P, from getG(n)
int n;    // sequence length; the modulus is P = n + 1
int C;    // exponent applied to B
int cnt;  // number of entries in st[]
/**
 * Find the smallest primitive root modulo P, where n = P - 1 is the order of
 * the multiplicative group (main sets P = n + 1 before calling this, and the
 * problem guarantees P is prime).
 *
 * i is a primitive root iff i^(n/q) != 1 (mod P) for every prime q dividing n.
 * The distinct prime factors of n are found by trial division, so this does
 * not rely on n having only the prime factors 2, 3, 5, 7.
 *
 * @param n  order of the group, i.e. P - 1
 * @return   the smallest i in [1, n] passing the check: 1 when n == 1
 *           (P == 2), otherwise the same value a search from 2 would give;
 *           0 if no candidate passes
 */
int getG(int n) {
    // Distinct prime factors of n; a 32-bit int has at most 9 of them.
    int fac[32];
    int nf = 0;
    for (int m = n, q = 2; m > 1; q++) {
        if (q > m / q) {  // no factor <= sqrt(m) remains, so m itself is prime
            fac[nf++] = m;
            break;
        }
        if (m % q == 0) {
            fac[nf++] = q;
            while (m % q == 0) m /= q;
        }
    }
    // Start at 1: it only passes when n == 1 (no prime factors to check);
    // for n >= 2, 1^(n/q) == 1 rejects it immediately.
    for (int i = 1; i <= n; i++) {
        bool ok = true;
        for (int j = 0; j < nf && ok; j++)
            if (fpw(i, n / fac[j]) == 1) ok = false;
        if (ok) return i;
    }
    return 0;
}
/**
 * In-place mixed-radix DFT over Z_P (recursive Cooley-Tukey).
 *
 * On return f[i] = sum_{m<deg} f_old[m] * w^(i*m) mod P, with w = G^(n/deg).
 * This assumes, as main arranges, that G is a primitive root of P = n + 1,
 * so w is a primitive deg-th root of unity.
 *
 * Level d splits f into st[d] interleaved subsequences
 * g_r[m] = f[m*st[d] + r] of length t = deg / st[d]. It transforms each with
 * NTT(..., t, d + 1) and combines them:
 *   F(w^i) = sum_{r<st[d]} w^(i*r) * G_r(w_t^(i mod t)),  w_t = w^st[d].
 *
 * @param f    deg values, overwritten with their transform
 * @param deg  transform length; equals n / (st[1] * ... * st[d-1])
 * @param d    recursion depth (1 at the top); st[d] is this level's radix
 */
void NTT(ll *f, int deg, int d) {
    if (deg == 1) return;
    const int k = st[d];     // radix of this level
    const int t = deg / k;   // length of each subsequence
    // Heap scratch instead of a stack VLA (non-standard C++; the chain of
    // frames needed ~2*n*8 bytes of stack). At most one call per depth is
    // active at a time: children use scratch[d + 1], so one buffer per depth
    // suffices. Every call at depth d has the same deg, so it grows only once.
    static std::vector<ll> scratch[64];
    std::vector<ll> &g = scratch[d];
    if ((int)g.size() < deg) g.resize(deg);
    // g[r * t + m] holds the m-th element of subsequence r.
    for (int i = 0; i < deg; i++) g[(i % k) * t + i / k] = f[i];
    for (int r = 0; r < k; r++) NTT(g.data() + r * t, t, d + 1);
    const ll wn = fpw(G, n / deg);  // primitive deg-th root of unity
    ll ww = 1;                      // ww = wn^i
    for (int i = 0; i < deg; i++, ww = ww * wn % P) {
        const int m = i % t;
        ll acc = 0;
        ll w = 1;                   // w = ww^r
        for (int r = 0; r < k; r++, w = w * ww % P)
            acc = (acc + w * g[r * t + m]) % P;
        f[i] = acc;
    }
}
// a holds A and b holds B; both are transformed in place.
ll a[N];
ll b[N];
/**
 * Reads n, C and the two length-n sequences A, B, then prints the coefficients
 * of A * B^C (length-n cyclic convolution), one per line, modulo P = n + 1.
 */
int main() {
    read(n);
    read(C);
    int rest = n;
    P = n + 1;
    for (int i = 0; i < n; i++) read(a[i]);
    for (int i = 0; i < n; i++) read(b[i]);

    // Record the prime factors of n: they are the NTT radices st[1..cnt].
    for (int p = 1; p <= 4; p++)
        for (; rest % prime[p] == 0; rest /= prime[p]) st[++cnt] = prime[p];

    G = getG(n);
    NTT(a, n, 1);
    NTT(b, n, 1);

    // Pointwise a *= b^C by binary exponentiation on C.
    for (; C; C >>= 1) {
        if (C & 1)
            for (int i = 0; i < n; i++) a[i] = a[i] * b[i] % P;
        for (int i = 0; i < n; i++) b[i] = b[i] * b[i] % P;
    }

    // Inverse transform: applying the forward DFT again maps index i to
    // (-i mod n) scaled by n, so reverse a[1..n-1] and divide by n. Modulo
    // P = n + 1, n is congruent to -1 and is its own inverse, hence "* n".
    NTT(a, n, 1);
    reverse(a + 1, a + n);
    for (int i = 0; i < n; i++) write(a[i] * n % P);
    return 0;
}