Problem M. Mediocre String Problem
Description
Given two strings s and t, count the number of tuples (i,j,k) such that
- 1 ≤ i ≤ j ≤ |s|
- 1 ≤ k ≤ |t|.
- j − i + 1 > k.
- The i-th character of s to the j-th character of s, concatenated with the first character of t to the
k-th character of t, is a palindrome.
A palindrome is a string which reads the same backward as forward, such as “abcba” or “xyzzyx”.
Exmple
standard input | standard output |
---|---|
ababa aba |
5 |
aabbaa aabb |
7 |
思路
由于s[i:j]长度大于t[1:k],s[i:j]+t[1:k]又是回文串,所以s[i:j]的前缀必是t[k:1],然后剩下的部分就是回文串。
所以考虑先将s逆序一下,然后扩展kmp求t对于s的extend数组。
因为s逆序,原本s[i:j]前缀是t[k:1]的变成s[i:j]的后缀是t[1:k]。除了后缀t[1:k]部分前面是回文串。所以最终就是对每个i,求右边界与i相连接的回文区间数,再乘上extent[i](即不同长度的t[1:k]数),求和就是答案。
对使用马拉车求出的每个最长回文区间,它对它右半边界每个位置的相连的右边界数的贡献为1。区间加就树状数组就好了。
感觉从处理后的串求出的回文区间转化原串对应的回文区间很容易搞错呢。
#include <bits/stdc++.h>
using namespace std;
const int N = 2e6 + 10;
typedef long long ll;
char ss[N];
char t[N];
string s;
int len[N];
ll tarr[N];
ll arr[N];
int nt[N];
int ext[N];
int n, m;
int lowbit(int x) {
return x & -x;
}
void addp(int p, ll v) {
while(p <= n) {
tarr[p] += v;
p += lowbit(p);
}
}
void add(int l, int r, int v) {
addp(l, v);
addp(r + 1, -v);
}
ll getv(int p) {
ll res = 0;
while(p) {
res += tarr[p];
p -= lowbit(p);
}
return res;
}
void gn() { //求next数组
int p = 0, a = 0;
nt[a] = m;
for(int i = 1; i < m; i++) {
if(i >= p || i + nt[i - a] >= p) {
if(i >= p) p = i;
while(p < m && t[p] == t[p - i]) p++;
a = i;
nt[i] = p - i;
} else nt[i] = nt[i - a];
}
}
void getex() {
int p = 0, a = 0;
for(int i = 0; ss[i]; i++) {
if(i >= p || i + nt[i - a] >= p) {
if(i >= p) p = i;
while(p < n && p - i < m && ss[p] == t[p - i]) p++;
a = i;
ext[i] = p - i;
} else ext[i] = nt[i - a];
}
}
void mana() {
s.clear();
s.push_back('$');
s.push_back('#');
for(int i = 0; ss[i]; i++) {
s.push_back(ss[i]);
s.push_back('#');
}
int a = 1, p = 1;
for(int i = 1; i < s.size(); i++) {
if(i < p && len[2 * a - i] < p - i) {
len[i] = len[2 * a - i];
}
else {
int l = p - i;
while(i + l < s.size() && s[i + l] == s[i - l]) l++;
len[i] = l;
a = i;
p = i + l;
}
if(len[i] == 1)continue;
int r = (i + len[i] - 2) / 2; //l, r为对应到原字符串上的回文右半区间
int l = (i + 1) / 2; //i+1为了防止回文区间长度为偶数时l小了
add(l ,r , 1);
}
}
int main() {
ios::sync_with_stdio(false);
cin >> ss;
cin >> t;
n = strlen(ss);
m = strlen(t);
reverse(ss, ss + n);
gn();
getex();
mana();
for(int i = 1; i <= n; i++) {
arr[i] = getv(i);
}
ll ans = 0;
for(int i = 0; i < n; i++) {
ans += 1ll * ext[i] * arr[i];
}
cout << ans <<endl;
}