【题解】 「NOI2018」你的名字 SAM+线段树合并 LOJ2720
你的名字
Legend
Link \(\textrm{to LOJ}\)。
Editorial
\(l=1,r=|S|\)
最暴力的做法是:对于 \(T\) 的每一个本质不同的子串判断它是否在 \(S\) 中出现过。
这其实有一个比较优秀的性质:我们枚举 \(T\) 的每一个前缀 \(T[1:i]\ (1 \le i \le |T|)\),并考察它的每一个后缀 \(T[k:i]\ (1 \le k \le i)\) 是否满足条件。即枚举子串的右端点。
则对于这个前缀 \(T[1:i]\),总存在一个分界点 \(p\) 满足 \(T[k:i]\ (k \ge p)\) 在 \(S\) 中出现,\(T[k:i]\ (1 \le k < p)\) 不在 \(S\) 中出现。
对于一个前缀 \(T[1:i-1]\),我们往最后新增一个字符成为 \(T[1:i]\),也总存在一个分界点 \(q\) 满足 \(T[k:i]\ (k > q)\) 是 \(T[1:i-1]\) 的子串(即之前已经出现过的子串),\(T[k:i]\ (1 \le k \le q)\) 不是\(T[1:i-1]\) 的子串(即之前没出现过的子串)。
综上,对于当前前缀 \(T[1:i]\):
- \(S\) 中出现过的后缀 \(T[k:i]\) 满足 \(p \le k \le i\);
- \(T\) 中第一次出现的后缀 \(T[k:i]\) 满足 \(1 \le k \le q\)。
所以对于所有右端点为 \(i\) 的子串,只有满足左端点 \(1 \le k \le \min(p-1 ,q)\) 的子串可以被算进答案。
显然,\(p\) 可以通过 \(T\) 串在 \(S\) 串的 SAM 上匹配得到。
\(q\) 可以通过对 \(T\) 串直接建立 SAM 得到(即当前节点与其后缀连接的最长长度之差 )。
无限制
区间询问对于 \(q\) 的求解没有影响,故考虑 \(p\) 的求解有何变化。
显然,我们需要求得每一个 \(S\) 串 SAM 节点的 endpos 集合,以便判断这个匹配串是出现在区间内部还是外部。endpos 集合可以通过在后缀连接树上线段树合并得到。
时间复杂度 \(O(|S| \log |S|+ \sum |T| \log |T|)\)。
Code
注意到一个实现上的细节是求解 \(p\) 时的匹配与正常的匹配不太相同,由于一个结点储存的子串长度是一个连续的区间,而此处还存在 endpos 集合的限制,那么在查询的时候我们应当每次将匹配长度减小 \(1\) 而不能直接跳到后缀连接。
#include <bits/stdc++.h>
#define debug(...) ;//fprintf(stderr ,__VA_ARGS__)
#define __FILE(x)\
freopen(#x".in" ,"r" ,stdin);\
freopen(#x".out" ,"w" ,stdout)
#define LL long long
const int MX = 5e5 + 23;
const LL MOD = 998244353;
int lsqs;
struct node{
int l ,r;
node *lch ,*rch;
}*root[MX << 1] ,POOL[MX * 40];
node *newnode(int l ,int r){
node *x = &POOL[++lsqs];
x->l = l ,x->r = r;
x->lch = x->rch = nullptr;
return x;
}
void set(node *x ,int p){
if(x->l == x->r) return ;
int mid = (x->l + x->r) >> 1;
if(p <= mid){
if(x->lch == nullptr) x->lch = newnode(x->l ,mid);
set(x->lch ,p);
}
else{
if(x->rch == nullptr) x->rch = newnode(mid + 1 ,x->r);
set(x->rch ,p);
}
}
node *combine(node *x ,node *y){
if(x == nullptr) return y;
if(y == nullptr) return x;
node *t = newnode(x->l ,x->r);
if(x->l == x->r);
else{
*t = *x;
t->lch = combine(x->lch ,y->lch);
t->rch = combine(x->rch ,y->rch);
}
return t;
}
int query(node *x ,int l ,int r){
if(x == nullptr || x->r < l || x->l > r) return 0;
if(l <= x->l && x->r <= r) return 1;
return query(x->lch ,l ,r) || query(x->rch ,l ,r);
}
int n;
struct SAMA{
SAMA(){tot = las = 1;}
struct __node{
int ch[26] ,len ,link;
__node(){memset(ch ,0 ,sizeof ch) ,len = link = 0;}
}a[MX * 2];
int las ,tot;
void extend(int c ,int id){
int p = las ,cur = las = ++tot;
root[tot] = newnode(1 ,n);
set(root[tot] ,id);
a[cur].len = a[p].len + 1;
for( ; p && !a[p].ch[c] ; p = a[p].link) a[p].ch[c] = cur;
if(!p) return a[cur].link = 1 ,void();
int q = a[p].ch[c];
if(a[p].len + 1 == a[q].len) return a[cur].link = q ,void();
int cl = ++tot;
a[cl] = a[q];
a[cl].len = a[p].len + 1;
a[q].link = a[cur].link = cl;
for( ; p && a[p].ch[c] == q ; p = a[p].link) a[p].ch[c] = cl;
}
int cnt[MX * 2] ,len[MX * 2] ,que[MX * 2];
void build(){
for(int i = 1 ; i <= tot ; ++i) cnt[a[i].len]++;
for(int i = 1 ; i <= tot ; ++i) cnt[i] += cnt[i - 1];
for(int i = 1 ; i <= tot ; ++i) que[cnt[a[i].len]--] = i;
for(int i = tot ; i >= 2 ; --i){
int x = que[i];
root[a[x].link] = combine(root[a[x].link] ,root[x]);
}
}
}S;
struct SAMB{
SAMB(){tot = las = mat = 1 ,matlen = 0;}
void clear(){
memset(a ,0 ,sizeof(__node) * (tot + 1));
las = tot = mat = 1 ,matlen = 0;
}
struct __node{
int ch[26] ,len ,link;
__node(){memset(ch ,0 ,sizeof ch) ,len = link = 0;}
}a[MX * 2];
int las ,tot ,mat ,matlen;
int run(int c ,int l ,int r ,int qwq){
for( ; mat > 1 && !S.a[mat].ch[c] ; ){
mat = S.a[mat].link;
matlen = S.a[mat].len;
}
if(S.a[mat].ch[c]) mat = S.a[mat].ch[c] ,++matlen;
for( ; mat > 1 ; mat = S.a[mat].link ,matlen = S.a[mat].len){
fuckyou:
int L = l + matlen - 1;
// 注意查询范围是 [l + matlen -1 ,r]
// 而不是 [l ,r]
if(L > r || !query(root[mat] ,L ,r)){
--matlen;
if(matlen == S.a[S.a[mat].link].len) continue;
goto fuckyou;
}
else break;
}
int mxr = std::min(a[las].len - a[a[las].link].len ,qwq - matlen);
debug("prefix %d contribute %d\n" ,qwq ,mxr);
return mxr;
}
void extend(int c){
int p = las ,cur = las = ++tot;
a[cur].len = a[p].len + 1;
for( ; p && !a[p].ch[c] ; p = a[p].link) a[p].ch[c] = cur;
if(!p) return a[cur].link = 1 ,void();
int q = a[p].ch[c];
if(a[p].len + 1 == a[q].len) return a[cur].link = q ,void();
int cl = ++tot;
a[cl] = a[q];
a[cl].len = a[p].len + 1;
a[q].link = a[cur].link = cl;
for( ; p && a[p].ch[c] == q ; p = a[p].link) a[p].ch[c] = cl;
}
}T;
char str[MX] ,tmp[MX];
LL query(int l ,int r){
T.clear();
LL ans = 0;
int m = strlen(tmp);
for(int i = 0 ; i < m ; ++i){
T.extend(tmp[i] - 'a');
ans += T.run(tmp[i] - 'a' ,l ,r ,i + 1);
}
return ans;
}
int main(){
scanf("%s" ,str);
n = strlen(str);
for(int i = 0 ; i < n ; ++i) S.extend(str[i] - 'a' ,i + 1);
S.build();
int Q = read();
while(Q--){
scanf("%s" ,tmp);
int l = read() ,r = read();
printf("%lld\n" ,query(l ,r));
}
return 0;
}