Palindrome Mouse(2019年牛客多校第六场C题+回文树+树状数组)
题目链接
题意
问\(s\)串中所有本质不同的回文子串中有多少对回文子串满足\(a\)是\(b\)的子串。
思路
参考代码:传送门
本质不同的回文子串肯定是要用回文树的啦~
在建好回文树后分别对根结点为\(0,1\)的子树进行\(dfs\),处理出以每个结点为根结点的子树的大小\(sz\)(也就是说有多少个回文子串以其为中心)和其\(dfs\)序,回文子串包含除了作为其他回文子串的中心被包含外,还可以不作为中心被包含,而这一部分则需要靠回文树的\(fail\)数组来进行处理。
我们先用\(vector\)存下有多少个结点的\(fail\)数组指向\(i\),然后把这些结点按照其对应的回文串长度进行排序,用树状数组来防止去重,加入这个结点对应的\(dfs\)序没被覆盖,那么就加上这个结点的\(sz\),否则就不加。此处举个例子帮助理解:\(cac,cedcacdec\)的\(fail\)数组都指向了\(c\),但是\(cedcacdec\)是\(cac\)子树中的结点,我们在加\(cac\)的时候已经把\(cedcacdec\)的贡献计算过了,如果再加一次就会重复,因此如果某个结点的\(dfs\)序被前面长度短的结点包含过,那么就不用加进答案中。
代码
#include <set>
#include <map>
#include <deque>
#include <queue>
#include <stack>
#include <cmath>
#include <ctime>
#include <bitset>
#include <cstdio>
#include <string>
#include <vector>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
typedef pair<LL, LL> pLL;
typedef pair<LL, int> pLi;
typedef pair<int, LL> pil;;
typedef pair<int, int> pii;
typedef unsigned long long uLL;
#define lson (rt<<1),L,mid
#define rson (rt<<1|1),mid + 1,R
#define lowbit(x) x&(-x)
#define name2str(name) (#name)
#define bug printf("*********\n")
#define debug(x) cout<<#x"=["<<x<<"]" <<endl
#define FIN freopen("/home/dillonh/CLionProjects/Dillonh/in.txt","r",stdin)
#define IO ios::sync_with_stdio(false),cin.tie(0)
const double eps = 1e-8;
const int mod = 1000000007;
const int maxn = 100000 + 7;
const double pi = acos(-1);
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3fLL;
int _, n, cnt;
char s[maxn];
vector<int> vec[maxn];
int sz[maxn], ls[maxn], rs[maxn], vis[maxn];
struct PAM {
//len数组表示以i为结尾的最长回文子串长度
//tot为结点数,lst为上一个字符加的位置
int N;
int str[maxn];
int ch[maxn][30], fail[maxn], len[maxn], cnt[maxn], tot, lst;
void init() {
for(int i = 0; i <= n + 1; ++i) {
cnt[i] = len[i] = fail[i] = 0;
for(int j = 0; j <= 26; ++j) ch[i][j] = 0;
}
N = lst = 0; tot = 1; fail[0] = fail[1] = 1; len[1] = -1;
}
inline void add(int c) {
int p = lst;
str[++N] = c;
while(str[N - len[p] - 1] != str[N]) p = fail[p];
if(!ch[p][c]) {
int now = ++tot, k = fail[p];
len[now] = len[p] + 2;
while(str[N - len[k] - 1] != str[N]) k = fail[k];
fail[now] = ch[k][c]; ch[p][c] = now;
}
lst = ch[p][c]; cnt[lst]++;
}
inline void solve() {
for(int i = tot; i; i--) {
cnt[fail[i]] += cnt[i];
}
}
}pam;
int tree[maxn];
void add(int x, int val) {
while(x < maxn) {
tree[x] += val;
x += lowbit(x);
}
}
int query(int x) {
int ans = 0;
while(x) {
ans += tree[x];
x -= lowbit(x);
}
return ans;
}
void dfs(int u) {
sz[u] = 1;
ls[u] = ++cnt;
for(int i = 1; i <= 26; ++i) {
if(pam.ch[u][i]) {
dfs(pam.ch[u][i]);
sz[u] += sz[pam.ch[u][i]];
}
}
rs[u] = cnt;
}
int main() {
#ifndef ONLINE_JUDGE
FIN;
#endif
scanf("%d", &_);
for(int __ = 1; __ <= _; ++__) {
scanf("%s", s + 1);
n = strlen(s + 1);
pam.init();
for(int i = 1; i <= n; ++i) pam.add(s[i] - 'a' + 1);
cnt = 0;
dfs(1);
dfs(0);
LL ans = 0;
for(int i = 2; i <= pam.tot; ++i) vec[i].clear();
for(int i = 2; i <= pam.tot; ++i) {
if(pam.fail[i] >= 2) vec[pam.fail[i]].emplace_back(i);
}
for(int i = 2; i <= pam.tot; ++i) {
vec[i].emplace_back(i);
sort(vec[i].begin(), vec[i].end(), [](int x, int y) {return pam.len[x] < pam.len[y];});
LL sum = 0;
for(int j = 0; j < (int)vec[i].size(); ++j) {
int u = vec[i][j];
if(query(ls[u]) == 0) {
sum += sz[u];
add(ls[u], 1);
add(rs[u] + 1, -1);
vis[j] = 1;
}
}
ans += sum - 1;
for(int j = 0; j < (int)vec[i].size(); ++j) {
if(!vis[j]) continue;
int u = vec[i][j];
add(ls[u], -1);
add(rs[u] + 1, 1);
vis[j] = 0;
}
}
printf("Case #%d: %lld\n", __, ans);
}
return 0;
}
版权声明:本文允许转载,转载时请注明原博客链接,谢谢~