HDU - 7084 Pty loves string
tag:border树,树状数组,二维数点
首先前缀x和后缀y拼起来等价于存在一点p,满足\(s[1,x]=s[p-x+1,p]\and s[p+1,p+y+1]=s[n-y+1,n]\),发现这其实是p和p+1两个位置在正串和反串中的border
border有个性质,border的border可以遍历当前点的所有border(画个大border和小border的图就明白了),所以根据这个性质,我们可以建个border树,分别满足两个条件的p就在正串和反串分别建立的border树对应节点的子树内。
所以现在就变成了两棵树中两个子树分别有多少个点满足x+1=y,对其中一个树特殊处理一下,其实就变成了子树内有多少个节点权值相同,也就转化成了二维数点问题。
因为是在子树中,所以有dfs序连续且嵌套的关系,可以把询问离线下来,直接两次dfs然后用bit统计子树中发生的增量(具体实现是一开始减掉然后后面加上)
如果不这么麻烦的话也可以二维数点的通常方法,离线变成二维前缀和的简单容斥,或者在线上主席树这样。
#include<bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10;
char s[N];
int n, Q, T, nxt[N], cnt, head[N], dfn, in[N], out[N], ans[N];
struct Node {int r, id;};
vector<Node>t[N];
struct edge {int to, nxt;} e[N << 1];
struct BIT {
int c[N];
int lowbit(int x) {return x & -x;}
void clear() {
for(int i = 1; i <= n + 1; ++i) c[i] = 0;
}
void add(int x, int v) {
for(int i = x; i <= n + 1; i += lowbit(i)) c[i] += v;
}
int query(int x) {
int ans = 0;
for(int i = x; i; i -= lowbit(i)) ans += c[i];
return ans;
}
}bit;
void getnxt() {
int j = 0;
for(int i = 2; i <= n; ++i) {
while(j && s[j + 1] != s[i]) j = nxt[j];
if(s[j + 1] == s[i]) ++j;
nxt[i] = j;
}
}
void ins(int u, int v) {
e[++cnt] = (edge) {v, head[u]};
head[u] = cnt;
e[++cnt] = (edge) {u, head[v]};
head[v] = cnt;
}
void dfs(int x, int fa) {
in[x] = ++dfn;
for(int i = head[x]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs(v, x);
}
out[x] = dfn;
}
void dfs2(int x, int fa) {
for(int i = 0; i < (int)t[x].size(); ++i) {
Node tmp = t[x][i];
ans[tmp.id] -= bit.query(out[tmp.r]) - bit.query(in[tmp.r] - 1);
}
bit.add(in[n - x], 1); // 反串中-1对应的dfs序
for(int i = head[x]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs2(v, x);
}
for(int i = 0; i < (int)t[x].size(); ++i) {
Node tmp = t[x][i];
ans[tmp.id] += bit.query(out[tmp.r]) - bit.query(in[tmp.r] - 1);
}
}
void solve() {
memset(head, 0, sizeof(head));
memset(ans, 0, sizeof(ans));
dfn = cnt = 0;
cin >> n >> Q;
bit.clear();
for(int i = 1; i <= n; ++i) t[i].clear();
scanf("%s", s + 1);
reverse(s + 1, s + n + 1);
for(int i = 1; i <= Q; ++i) {
int l, r;
cin >> l >> r;
t[l].push_back({r, i});
}
getnxt();
for(int i = 1; i <= n; ++i) {
ins(nxt[i], i);
}
dfs(0, 0);
reverse(s + 1, s + n + 1);
getnxt();
memset(head, 0, sizeof(head));
cnt = 0;
for(int i = 1; i <= n; ++i) {
ins(nxt[i], i);
}
dfs2(0, 0);
for(int i = 1; i <= Q; ++i) printf("%d\n", ans[i]);
}
int main() {
cin >> T;
while(T--) solve();
return 0;
}