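二维 FFT（代码里实际用的是模 998244353 的 NTT）
二维 DFT 是可分离的：二维变换核 $\omega^{xi+yj} = \omega^{xi}\,\omega^{yj}$，再加上 DFT 是线性变换，所以二维 DFT 就等于先对每一行做一维 DFT，再对每一列做一维 DFT。
因此二维卷积可以这样算：两个矩阵都先按行 DFT、再按列 DFT，逐点乘起来，然后按列 IDFT、再按行 IDFT 回去（行、列的先后顺序可以交换）。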
代码硕丑无比
code:
#include<bits/stdc++.h>
#define N 5005
#define mod 998244353
#define G 3
#define int long long
using namespace std;
// The constant pi in double precision. No code in this file reads it; the
// transforms below are number-theoretic (modular), not complex-valued.
const double pi = std::acos(-1.0);
/*
 * Modular exponentiation by repeated squaring.
 *   x : base (any value; it is reduced modulo `mod` first)
 *   y : exponent, expected to be >= 0
 * Returns x^y mod `mod`, in [0, mod).
 * Reducing the base up front keeps every intermediate product below mod^2,
 * so `x * x` cannot overflow a 64-bit integer even when the caller passes an
 * unreduced or negative base.
 */
int qpow(int x, int y) {
    x %= mod;
    if (x < 0) x += mod;  // C++ '%' keeps the sign of the dividend
    int ret = 1;
    for (; y; y >>= 1, x = x * x % mod)
        if (y & 1) ret = ret * x % mod;
    return ret;
}
// Bit-reversal table of size N. main fills it for its transform length;
// ntt below builds its own permutation and does not read it.
int rev[N];

/*
 * In-place iterative number-theoretic transform modulo `mod` with root G.
 *   a : array holding at least n residues; assumed to be in [0, mod)
 *   n : transform length; must be a power of two dividing mod - 1
 *   o : 1 for the forward transform, -1 for the inverse transform
 * The inverse also multiplies by n^{-1}, so ntt(a, n, 1) followed by
 * ntt(a, n, -1) restores a.
 * The bit-reversal permutation is derived here from n itself, so the result
 * is correct for any valid n, not only for a length a global table was
 * prepared for.
 */
void ntt(int *a, int n, int o) {
    // Bit-reversal permutation: j is kept equal to the bit-reverse of i.
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }
    // Forward uses G, inverse uses G^{-1}; chosen once instead of per call site.
    const int root = (o == 1) ? G : qpow(G, mod - 2);
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        const int w0 = qpow(root, (mod - 1) / len);  // primitive len-th root of unity
        for (int j = 0; j < n; j += len) {
            int wn = 1;
            for (int p = j; p < j + half; p++, wn = wn * w0 % mod) {
                int X = a[p], Y = wn * a[p + half] % mod;
                a[p] = (X + Y) % mod;
                a[p + half] = (X - Y + mod) % mod;
            }
        }
    }
    if (o == -1) {
        const int n_inv = qpow(n, mod - 2);  // only needed for the inverse
        for (int i = 0; i < n; i++) a[i] = a[i] * n_inv % mod;
    }
}
// Global, zero-initialised program state used by main.
// n, m: problem size values read/derived in main.
int n, m;
// a, b: the two input coefficient matrices (also reused for the transforms).
// Each N x N matrix of `int` (a macro for long long) is about 200 MB.
int a[N][N], b[N][N];
// al, bl: N x N scratch matrices available to main (e.g. transposed copies).
int al[N][N], bl[N][N];
signed main() {
scanf("%lld", &n);
for(int i = 0; i <= n; i ++)
for(int j = 0; j <= n; j ++) scanf("%lld", &a[i][j]);
for(int i = 0; i <= n; i ++)
for(int j = 0; j <= n; j ++) scanf("%lld", &b[i][j]);
int len = 1;
for(;len <= n + n; len <<= 1);
for(int i = 1; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (len >> 1));
for(int i = 0; i <= n; i ++) ntt(a[i], len, 1), ntt(b[i], len, 1); //先搞列
m = n;
n = len;
for(int i = 0; i <= n; i ++)
for(int j = 0; j <= n; j ++) al[j][i] = a[i][j], bl[j][i] = b[i][j];//把矩阵转置一下
for(int i = 0; i <= n; i ++) ntt(al[i], len, 1), ntt(bl[i], len, 1);//在搞行
for(int i = 0; i <= n; i ++) {
for(int j = 0; j <= n; j ++) {
al[i][j] = al[i][j] * bl[i][j] % mod;//乘起来
}
ntt(al[i], len, -1);//先按照列IDFT回去
for(int j = 0; j <= n; j ++)
a[j][i] = al[i][j];//转置
}
for(int i = 0; i < n; i ++) ntt(a[i], len, -1);//再按照行IDFT回去
for(int i = 0; i < 2 * m + 1; i ++) {//按题目要求输出即可
int ans = 0;
for(int j = 0; j < n; j ++) ans = ans ^ a[i][j];
printf("%lld ", ans);
} printf("\n");
for(int i = 0; i < 2 * m + 1; i ++) {
int ans = 0;
for(int j = 0; j < n; j ++) ans = ans ^ a[j][i];
printf("%lld ", ans);
} printf("\n");
// for(int i = 0; i <= n + m; i ++) printf("%lld ", a[i]);
return 0;
}
我还是好菜啊