【UOJ #34】多项式乘法
http://uoj.ac/problem/34
看了好长时间的FFT和NTT啊qwq在原根那块磨蹭了好久_(:з」∠)_
首先设答案多项式的长度拓展到2的幂次后为n,我们只要求出一个g(不是原根)满足\(i\in \{1\dots n\},g^i\)互不相同,且\(g^n=1\)。
把这个g当做“FFT里面的主n次单位根”的类似物。
而且\(g^{\frac n2}=-1\),因为\(g^{\frac n2}\)与\(g^n\)不相同且\((g^{\frac n2})^2=g^n=1\),所以\(g^{\frac n2}\)只能是-1。
剩下的只要选一个够大的模数满足答案多项式的所有系数都小于这个模数就可以了。
我选的模数是998244353(\(7×17×2^{23}+1\),一个质数,UOJ模数)。不是所有的模数p都可以,像\(10^9+7\)就不可以,因为此时p-1的因子2的指数不够大。只有p-1的因子2的指数c足够大,\(2^c>n\)时才可以。
这里我写了一个暴力找到了一个g=646。
//998244353
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int p = 998244353;
const int n = 1048576;
bool can[p];
int main() {
bool flag = false;
for (int num = 1; num < p; ++num) {
ll re = num;
flag = true;
for (int i = 1; i <= n; ++i) {
if (can[re]) {flag = false; break;}
can[re] = true;
re = re * num % p;
}
if (!flag || re != num) {
re = num;
for (int i = 1; i <= n; ++i) {
if (can[re]) can[re] = false;
re = re * num % p;
}
continue;
}
printf("%d\n", num);
return 0;
}
/*
freopen("tab.txt", "w", stdout);
int num = 646; ll ret = 1;
for (int i = 1; i <= (n >> 1); ++i) {
ret = ret * num % p;
printf("%d %I64d \n", i, ret);
}
*/
}
求出g后就可以NTT了,不过也需要预处理一些分治实现NNT时(一般是迭代实现,这里也是)n不断除2变小需要用到的不同的“主n次单位根”和“主n次单位根的逆元”。
一开始我对原根(及主n次单位根)的定义比较模糊,没有预处理“主n次单位根”的逆元而直接用负的“主n次单位根”导致逆DNNT出错qwq
NTT有取模果然慢啊,不过没有FFT的复数精度误差。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int p = 998244353;
const int N = 1048576;
const int g = 646;
int rev[N], WN[23], nWN[23], n;
int ipow(int a, int b) {
int ret = 1, w = a;
while (b) {
if (b & 1) ret = 1ll * ret * w % p;
w = 1ll * w * w % p;
b >>= 1;
}
return ret;
}
void DNT(int *a, int *A, int flag) {
for (int i = 0; i < n; ++i) A[rev[i]] = a[i];
int tmp = 1;
for (int m = 2; m <= n; m <<= 1, ++tmp) {
int mid = m >> 1, wn = flag == 1 ? WN[tmp] : nWN[tmp];
for (int i = 0; i < n; i += m) {
int w = 1;
for (int j = 0; j < mid; ++j) {
int t = A[i + j], u = 1ll * A[i + j + mid] * w % p;
A[i + j] = (t + u) % p;
A[i + j + mid] = (t - u + p) % p;
w = 1ll * w * wn % p;
}
}
}
if (flag == -1) {
int ni = ipow(n, p - 2);
for (int i = 0; i < n; ++i)
A[i] = 1ll * A[i] * ni % p;
}
}
int da[N], db[N], dc[N];
void NTT(int *a, int lena, int *b, int lenb, int *ans, int n) {
DNT(a, da, 1); DNT(b, db, 1);
for (int i = 0; i < n; ++i) dc[i] = 1ll * da[i] * db[i] % p;
DNT(dc, ans, -1);
}
void init() {
WN[20] = g; nWN[20] = ipow(g, p - 2);
for (int i = 19; i >= 1; --i) {
WN[i] = 1ll * WN[i + 1] * WN[i + 1] % p;
nWN[i] = ipow(WN[i], p - 2);
}
int num = n, tot = 0, res;
while (num) {++tot; num >>= 1;}
n = 1 << tot;
for (int i = 0; i < n; ++i) {
num = i; res = 0;
for (int j = tot - 1; j >= 0; --j) {
if (num & 1) res |= (1 << j);
num >>= 1;
}
rev[i] = res;
}
}
int lena, lenb, a[N >> 1], b[N >> 1], ans[N];
int main() {
scanf("%d%d", &lena, &lenb); ++lena; ++lenb;
for (int i = 0; i < lena; ++i) scanf("%d", a + i);
for (int i = 0; i < lenb; ++i) scanf("%d", b + i);
n = lena + lenb - 1;
init();
NTT(a, lena, b, lenb, ans, n);
int totlen = lena + lenb - 1;
for (int i = 0; i < totlen; ++i) printf("%d ", ans[i]);
puts("");
return 0;
}
一个板子都写了这么长时间省选是要滚粗吗→_→
NOI 2017 Bless All