[loj6388] 「THUPC2018」赛艇 / Citing
Description
给你一个\(~n \times m~\)的\(~01~\)矩阵,一个人在这个矩阵中走了\(~k~\)步,每一次都往四联通方向中的一个走一步。给定这个人每一步走的方向,已知这个人经过的每一步都没有经过原矩阵中\(~1~\)的位置。问合法的起点有多少种?保证至少有一组解。\(~1 \leq n, m \leq 1500, ~k \leq 5 \times 10 ^ 6~\).
Solution
不难发现那条路径通过补全\(~0~\)之后其实就是一个\(~01~\)矩阵,其中的\(~1~\)就是原路径。问题变成了把该矩阵放在原矩阵中(严格内含)不产生冲突的方案数,实质上就是或起来全是\(~0~\)的方案数。考虑怎么快速求这个问题。把该矩阵通过补\(~0~\)变成和原矩阵一样大的规模,把两个矩阵都拉成长度为\(~n \times m~\)的序列,倒序一个序列做\(~FFT~\)或\(~NTT~\)在看对应位置上是否为\(~0~\)统计答案即可。至于这样为什么是对的,可以考虑这个对应位置的数代表的东西到底是什么,卷积中\(~ans_i~\)代表下标和为\(~i~\)的各项乘积之和,由于之前做过一个区间反转,所以这个\(~ans_i~\)就代表路径矩阵在原矩阵中起始位置为\(~i~\)时矩阵各项匹配起来的乘积的和,而在只有\(~0, 1~\)的情况下,乘法和或的运算法则一样。所以当\(~ans_i~\)为\(~0~\)时,就代表这个匹配位置是合法的,因为没有任何一个\(~1~\)同位。
Code
#include<bits/stdc++.h>
#define For(i, j, k) for(int i = j; i <= k; ++i)
#define Forr(i, j, k) for(int i = j; i >= k; --i)
using namespace std;
inline int read() {
int x = 0, p = 1; char c = getchar();
for(; !isdigit(c); c = getchar()) if(c == '-') p = -1;
for(; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
return x *= p;
}
inline void File() {
#ifndef ONLINE_JUDGE
freopen("loj6388.in", "r", stdin);
freopen("loj6388.out", "w", stdout);
#endif
}
const int N = 1500 + 10, M = (N * N) << 2, mod = 998244353;
int a[M], b[M], rev[M], powg[M], invg[M], k;
int n, m, cnt1, cnt2, siz, len, bit, c[N << 1][N << 1];
char ss[M];
inline int qpow(int a, int b) {
static int res;
for (res = 1; b; a = 1ll * a * a % mod, b >>= 1)
if (b & 1) res = 1ll * res * a % mod;
return res;
}
inline void NTT(int *a, int flag) {
For(i, 0, siz - 1) if (rev[i] > i) swap(a[rev[i]], a[i]);
for (int i = 2; i <= siz; i <<= 1) {
int wn = flag ? powg[i] : invg[i];
for (int j = 0; j < siz; j += i) {
int w = 1;
for (int k = 0; k < (i >> 1); ++ k, w = 1ll * w * wn % mod) {
int x = a[j + k], y = 1ll * w * a[j + k + (i >> 1)] % mod;
a[j + k] = (x + y) % mod, a[j + k + (i >> 1)] = (x - y + mod) % mod;
}
}
}
if (!flag) {
int g = qpow(siz, mod - 2);
For(i, 0, siz) a[i] = 1ll * a[i] * g % mod;
}
}
int main() {
File();
n = read(), m = read(), k = read();
For(i, 1, n) {
scanf("%s", ss + 1);
For(j, 1, m) a[(i - 1) * m + j - 1] = ss[j] - 48;
}
cnt1 = n * m - 1;
int x2 = n, y2 = m, x0 = n, y0 = m, lx = n, ly = m;
scanf("%s", ss + 1), c[lx][ly] = 1;
For(i, 1, k) {
if (ss[i] == 'w') c[-- lx][ly] = 1;
if (ss[i] == 'a') c[lx][-- ly] = 1;
if (ss[i] == 's') c[++ lx][ly] = 1;
if (ss[i] == 'd') c[lx][++ ly] = 1;
x0 = min(x0, lx), y0 = min(y0, ly);
x2 = max(x2, lx), y2 = max(y2, ly);
}
For(i, x0, x0 + n - 1) For(j, y0, y0 + m - 1) b[cnt1 - (cnt2 ++)] = c[i][j];
-- cnt2;
len = cnt1 + cnt2;
for (siz = 1; siz <= len; siz <<= 1) ++ bit;
For(i, 0, siz - 1) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
int g = qpow(3, mod - 2);
for (int i = 1; i <= siz; i <<= 1) {
invg[i] = qpow(g, (mod - 1) / i);
powg[i] = qpow(3, (mod - 1) / i);
}
NTT(a, 1), NTT(b, 1);
For(i, 0, siz - 1) a[i] = 1ll * a[i] * b[i] % mod;
NTT(a, 0);
int ans = 0;
For(i, 1, n - (x2 - x0)) For(j, 1, m - (y2 - y0))
if (a[cnt1 + (i - 1) * m + j - 1] == 0) ++ ans;
cout << ans << endl;
return 0;
}