[coci2015-2016 coii] Palinilap【字符串 哈希】
传送门:http://www.hsin.hr/coci/archive/2015_2016/
进去之后点底下的那个。顺带说一句,题目既不是一个英文单词,也不是克罗地亚单词,估计只是从回文串的英文单词palindrome取出前5个字母,再反向复制一遍,造了一个回文串单词。
官方题解说的有点短小精悍,总之步骤就是这样子的:先统计一下原始串有几个回文子串。注意这里我用的是哈希+二分求的,manacher不知道可不可以,可能会很麻烦,因为后面的步骤是奇偶分开讨论的,这里仅仅考虑回文子串的长度为奇数的情况,偶数情况同理,只有一个地方有一点点小差别,后面我会标注(这个小差别葬送了我至少5个小时的调试时间!)。
先解释一下,后文的a[]数组就是我存取原串的数组。
那么现在只需求出,对于每个位置,这个位置的字母如果换成另外一个,会减少以及增加几个回文子串。先考虑增加几个。开一个数组inc[100005][26],inc[i][j]表示第i个位置的字母换成字母j会新增几个回文子串。该怎么计算呢?假设当前正在计算以位置i为中心,可以向两边扩展的长度,按照哈希+二分的算法,当你的left == right时(left和right表示的是二分的上界以及下界),此时是能扩展的最大长度(在我的程序里面,这个长度是包括了第i个字符的,也就是说以i为中心的最长回文子串的左边界是i - left + 1,右边界是i + right - 1),再扩展一点点就会遇到字符不匹配的情况,因此我们就让这两个字符匹配,对应了题目要求的只能修改一个字符。这两个字符匹配后,就再哈希+二分一次,看一下能再匹配几个字符,假设包括修改的那个字符,一共多匹配了x个,那么就:
inc[i - left][a[i + left] - 'a'] += x; //表示把第(i - left)位置的字符改成第(i + left)字符会新增x个回文子串
inc[i + left][a[i - left] - 'a'] += x; //类似以上
现在就要求对于每个位置,如果该位置替换一个字母会导致减少几个回文子串,开个数组dec[100005]。举个例子,这里有一个回文串:
a b c d e f e d c b a
假设我们把第三位的'c'换掉,随便换成什么,比如说换成g:
a b g d e f e d c b a
那么这就减少了3个以'f'为中心的回文子串。因此对于一个当前位置i,以他为中心的最长扩展长度为left(就是哈希+二分的结果),如果位置j被替换成其它任意一个串:
①,如果i - left + 1 <= j <= i - 1,那么会减少(j - (i - left))个以位置i为中心的回文子串。所以要dec[j] += (j - (i - left)),代表j位置换了字符,会导致多减少了(j - (i - left))个回文子串。
②,如果i + 1 <= j <= i + left - 1,那么会减少((i + left) - j)个以位置i为中心的回文子串。
这里只考虑第①种情况,第②种同理。显然,不可能按照刚刚说的对于每个位置j都加一遍,这样子效率明显应对不过来。所以,对于当前位置i,(i - left)是一个定值,也就是说(j - (i - left))这个值,随着j的增大而增大,更具体的,是对dec数组的某一段加一个等差数列,比如刚刚那个例子,a b g d e f e d c b a, dec[1] += 1, dec[2] += 2, dec[3] += 3。也就是说,现在的任务就是要快速对一个数组(这里是dec[]数组)的某个区间加上一个以1为首项,1为公差的等差数列。类似差分数组,我们先构建两个辅助数组c1[100005], c2[100005]。
对于“dec[1] += 1, dec[2] += 2, dec[3] += 3”这个操作,我们可以先让区间[1, 3]每项各加1,再让区间[2, 3]各加1,再让区间[3, 3]各加1。按照差分数组的思想,应该要这么做:
1 2 3 4 5
c2 +1 -1
c2 +1 -1
c2 +1 -1
然后再对c2数组求一次前缀和,存到dec里,那么就得到了dec[]的所有值。可是这样一来还不如刚刚的直接操作dec数组快,所以还需要借用一下c1数组。观察到c2数组里[1, 3]的位置都加了一次1,而4位置减了3,因此我们构建的c1数组,其前缀和就是c2数组!类比刚刚的思想,应该这么做:
1 2 3 4 5
c1 +1 -1
c1 -3 +3
对c1数组求前缀和,保存在c2里,再对c2数组求前缀和,保存在dec里,便得出了dec的所有的值,具体看代码里的incc与decc函数。
之前说回文子串如果有偶数长度,有点小特殊,这就特殊在,如果是奇数长度,那么中心字符被换成了另外一个字符,并不会影响以这个中心字符为中心的回文子串数量,而如果是偶数长度,中心字符有两个,比如a b c c b a,中心字符就是两个c,此时替换任意一个中心字符,会影响到以这两个中心字符为中心的回文子串数量,千万注意!最后,哈希的mod如果让他自然溢出unsigned long long,会被卡一个点,我与std用的都是1e9+7。
#include <cstdio> #include <cstring> #include <algorithm> const int maxn = 100005; const long long base = 131, mod = 1000000007; int n, mx, mx_id; long long inc[maxn][26], dec[maxn], c1[maxn], c2[maxn], ori_ans, delta_ans; char a[maxn]; long long poww[maxn], hash1[maxn], hash2[maxn]; inline long long get_hash1(int left, int right) { return ((hash1[right] - hash1[left - 1] * poww[right - left + 1]) % mod + mod) % mod; } inline long long get_hash2(int left, int right) { return ((hash2[left] - hash2[right + 1] * poww[right - left + 1]) % mod + mod) % mod; } inline void incc(int left, int right) { ++c1[left]; c1[right + 1] -= (right - left + 2); c1[right + 2] += (right - left + 1); } inline void decc(int left, int right) { ++c1[right + 2]; c1[left] += (right - left + 1); c1[left + 1] -= (right - left + 2); } int main(void) { freopen("palinilap.in", "r", stdin); freopen("palinilap.out", "w", stdout); scanf("%s", a + 1); n = strlen(a + 1); a[0] = '$'; poww[0] = 1ull; for (int i = 1; i <= n; ++i) { poww[i] = poww[i - 1] * base % mod; } for (int i = 1; i <= n; ++i) { hash1[i] = (hash1[i - 1] * base + a[i]) % mod; } for (int i = n; i; --i) { hash2[i] = (hash2[i + 1] * base + a[i]) % mod; } int left, right, mid; for (int i = 1; i < n; ++i) { for (int j = i; j < i + 2; ++j) { left = 0; right = std::min(i, n - j + 1); while (left < right) { mid = (left + right + 1) >> 1; if (get_hash1(i - mid + 1, i) == get_hash2(j, j + mid - 1)) { left = mid; } else { right = mid - 1; } } ori_ans += (long long)left; if (left > 0) { if (i == j) { incc(i - left + 1, i - 1); decc(j + 1, j + left - 1); } else { incc(i - left + 1, i); decc(j, j + left - 1); } } if (i - left <= 0 || j + left > n) { continue; } long long & tem1 = inc[i - left][a[j + left] - 'a']; long long & tem2 = inc[j + left][a[i - left] - 'a']; int ori_left = left; right = std::min(i - left - 1, n - j - left); left = 0; while (left < right) { mid = (left + right + 1) >> 1; if (get_hash1(i - ori_left - mid, i - ori_left - 1) == get_hash2(j + ori_left + 1, j + ori_left + mid)) { left = mid; } else { right = mid - 1; } } tem1 += left + 1; tem2 += left + 1; } } for (int i = 1; i <= n; ++i) { c2[i] = c2[i - 1] + c1[i]; dec[i] = dec[i - 1] + c2[i]; } for (int i = 1; i <= n; ++i) { for (int j = 0; j < 26; ++j) { if (j == a[i] - 'a') { continue; } delta_ans = std::max(delta_ans, inc[i][j] - dec[i]); } } ++ori_ans; printf("%I64d\n", ori_ans + delta_ans); return 0; }