【题解】CF17E - Palisection
题目大意
给定一个长度为 \(n\) 的小写字母串,试求出该串中相交的回文子串对数(包含也算相交),顺序不同、子串位置完全相同的一对回文子串视为同一对回文子串,结果对 \(51123987\) 取模。
\(1 \leq n \leq 2 \times 10^6\)
解题思路
首先发现题目和 回文 有关,优先考虑 \(manacher\) 算法。考虑按照题目的思路维护,发现直接维护相交的回文子串数量似乎并不可做。因此我们考虑 正难则反,相交的回文子串数量 \(=\) 所有回文子串对数 \(-\) 不相交的回文子串对数。考虑分开求解这两个值,设 $x = $ 该串的回文子串数量,显然本质不同的回文子串对共有 \(\frac{x \times (x - 1)}{2}\) 对。对于任意中点 \(i\),设以 \(i\) 为中点的最长回文子串半径为 \(len_i\)(这里是 \(manacher\) 算法中求出的,所以 \(s_{i - len_i} \neq s_{i + len_i}, s_{i - len_i + 1} = s_{i + len_i - 1}\)),\(\forall 1 \leq j \leq len_i\),存在 \([s_{i - len_i + 1}, s_{i + len_i - 1}]\) 一定是回文子串。因为\(len_i\) 是根据处理后的字符串求出的,所以中点 \(i\) 实际的最长回文子串半径 \(=\) 其贡献的回文子串数量 \(=\) \(\lfloor \frac{len_i}{2} \rfloor\)。难点在于求出不相交的回文子串对数。
我们可以试着手玩一下不相交的回文子串对,发现设左端的回文子串终点为 \(x\),右端的回文子串起点为 \(y\),则一定有 \(x < y\)。因此我们可以考虑枚举左端的回文子串的终点,对于以 \(i\) 为左端回文子串终点时,一个很显然的性质是终点 \(i\) 的贡献为 \(i\) 左侧(包含 \(i\))的回文子串数量 \(\times\) 以 \(i + 1\) 开头的回文子串数量。这样维护可以避免一个字符串被重复维护多次。考虑在 \(manacher\) 算法的过程中维护上面的信息。
对于 \(i\) 左侧的回文子串数量,我们可以通过统计 \(i\) 左侧的回文子串右端点的数量来确定。不妨设 \(l_i\) 为以 \(i\) 为左端点的回文子串数量,\(r_i\) 为以 \(i\) 为右端点的回文子串数量。那么 \(i\) 左侧的回文子串数量为 \(\sum\limits_{j = 1}^{i - 1} r_j\),以 \(i + 1\) 为开头的回文子串数量为 \(l_{i + 1}\)。最终的答案为 \(\sum\limits_{i = 1}^{n - 1} (\sum\limits_{j = 1}^{i} r_j) \times l_{i + 1}\)。
我们可以发现,当一个回文串的两端各删减一个字符后产生的字符串依然是回文串。因此,\(\forall 1 \leq j \leq len_i\),一定存在一个半径为 \(j\)(包含中点)的回文子串给位置 \(i - j + 1\) 贡献一个回文子串的右端点 \(i + j - 1\)。换言之,根据上文提到的性质,在半径范围内一定存在若干个回文子串,设它们的半径为 \(j\),\(\forall 1 \leq j \leq len_i\),这些回文子串的起点和终点分别为 \(i - j + 1\) 和 \(i + j - 1\)。这样在 \([i - j + 1, i]\) 内,每个位置我们都可以贡献一个回文子串的右端点;在 \([i, i + j - 1]\) 内,每个位置我们都可以贡献一个回文子串的左端点。每次求出 \(len_i\) 后给 \([l_{i - len_i + 1}, l_i] + 1, [r_i, r_{i + len_i - 1}] + 1\) 即可。
因为题目只需要先修改多次 \(l\) 和 \(r\),最后仅查询一次。因此我们可以考虑使用 差分 来维护,详见代码。注意最后 \(r\) 数组还需要再套一遍前缀和。另外,这道题可能会出现有符号整数溢出,请注意 取模 问题。
参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int maxn = 2e6 + 5;
const int maxm = 6e6 + 5;
const int mod = 51123987;
int n, cnt;
int len[maxm], s[maxm], t[maxm];
char a[maxn], str[maxm];
long long ans, tot;
int main()
{
int mid, r;
mid = r = 0;
scanf("%d", &n);
scanf("%s", a);
str[0] = '#';
str[++cnt] = '|';
for (int i = 0; i < n; i++)
{
str[++cnt] = a[i];
str[++cnt] = '|';
}
for (int i = 1; i <= cnt; i++)
{
if (i <= r)
len[i] = min(len[2 * mid - i], r - i + 1);
while (str[i - len[i]] == str[i + len[i]])
len[i]++;
ans = (ans + len[i] / 2 % mod) % mod;
if (i + len[i] - 1 >= r)
{
mid = i;
r = i + len[i] - 1;
}
s[i - len[i] + 1]++, s[i + 1]--;
t[i]++, t[i + len[i]]--;
}
for (int i = 1; i <= cnt; i++)
{
s[i] = (s[i] + s[i - 1]) % mod;
t[i] = (t[i] + t[i - 1]) % mod;
if (str[i] != '|')
s[i / 2] = s[i], t[i / 2] = t[i];
}
for (int i = 2; i <= n; i++)
{
t[i] = (t[i - 1] + t[i]) % mod;
tot = (tot + (1LL * s[i] * t[i - 1]) % mod) % mod;
}
if (ans & 1)
ans = (ans - 1) / 2 % mod * ans % mod;
else
ans = ans / 2 % mod * (ans - 1) % mod;
ans = (ans - tot + mod) % mod;
printf("%lld\n", ans);
return 0;
}