题目
解法
考虑如何快速判断行与列的回文:
- 行。最多只有一个字符出现奇数次则可构造出合法解。可以预处理每一行的每个字符出现次数的前缀和,这样单次询问是 \(\mathcal O(26)\) 的;更好的方法是预处理 \(pre_{i,j}\) 表示第 \(i\) 行的前 \(j\) 列的字符状态 —— 这是一个 \(26\) 位二进制数。这样通过异或即可 \(\mathcal O(1)\) 询问。
- 列。相对应的行必须每种字符的出现次数一样。枚举 \([l,r]\) 即列区间,可以将每一行对应列区间的字符出现次数进行哈希,如果这一行是合法的对应 \(s_i\) 就是哈希值,否则是一个与其它任何数字都不相等的数。然后把 \(s\) 丢进 \(\text{manacher}\) 里跑一遍即可。
时间复杂度 \(\mathcal O(nm^2)\)。
代码
为了保险,实现的时候使用了双模数哈希。
#include <cstdio>
#include <cstdlib>
#include <iostream>
using namespace std;
typedef long long ll;
typedef pair <ll, ll> Pair;
const int N = 255;
const ll mod1 = 998244353, mod2 = 19260817, base = 163;
Pair h[N][N], s[N << 1];
int n, m, ans, cnt[N][N][27], pre[N][N], len, p[N << 1];
char ch[N][N];
int read() {
int x = 0, f = 1; char s;
while((s = getchar()) > '9' || s < '0') {
if(s == '-') f = -1;
if(s == EOF) exit(0);
}
while(s <= '9' && s >= '0') {
x = (x << 1) + (x << 3) + (s ^ 48);
s = getchar();
}
return x * f;
}
void manacher() {
len = (n << 1);
int R = 0, mid;
for(int i = 1; i <= len; ++ i) {
if(s[i].first < 0) continue;
if(i < R) p[i] = min(p[(mid << 1) - i], R - i);
else p[i] = 1;
while(i - p[i] > 0 && s[i - p[i]] == s[i + p[i]]) ++ p[i];
if(R < i + p[i]) {
R = i + p[i];
mid = i;
}
ans += (p[i] >> 1);
}
}
Pair Hash(const int i, const int j) {
Pair ret; ret.first = ret.second = 0;
for(int k = 0; k < 26; ++ k) {
ret.first = (ret.first * base % mod1 + cnt[i][j][k]) % mod1;
ret.second = (ret.second * base % mod2 + cnt[i][j][k]) % mod2;
}
return ret;
}
bool check(const int i, const int l, const int r) {
int x = pre[i][l - 1] ^ pre[i][r];
return x == 0 || (!(x - (x & -x))); // (x & -x) 是位最低的 1 的数,相减等于 0 就是只有一个 1
}
int main() {
n = read(), m = read();
for(int i = 1; i <= n; ++ i) {
scanf("%s", ch[i] + 1);
for(int j = 1; j <= m; ++ j) {
for(int k = 0; k < 26; ++ k) cnt[i][j][k] += cnt[i][j - 1][k];
++ cnt[i][j][ch[i][j] - 'a'];
h[i][j] = Hash(i, j);
pre[i][j] = pre[i][j - 1] ^ (1 << ch[i][j] - 'a');
}
}
for(int i = 1; i <= m; ++ i)
for(int j = i; j <= m; ++ j) {
for(int k = 1; k <= n; ++ k) {
s[(k << 1) - 1] = make_pair(0, 0);
if(check(k, i, j)) s[k << 1] = make_pair((h[k][j].first - h[k][i - 1].first + mod1) % mod1, (h[k][j].second - h[k][i - 1].second + mod2) % mod2);
else s[k << 1] = make_pair(-1ll * k, -11ll * k); // 保证不重,不然回文判断会出问题
}
manacher();
}
printf("%d\n", ans);
return 0;
}