Luogu4770 NOI2018 你的名字 SAM、主席树
UPD:发现之前被smy误导的一个细节,改过来之后就AC了……
一道比较套路的SAM题,虽然我连套路都不会……
先考虑前\(68pts\),也就是\(l=1 , r=|S|\)的情况。我们对\(S\)建好SAM,把\(T\)扔到\(S\)的SAM上匹配,如果不考虑本质不同子串的性质,那么答案就是\(\sum\limits_{i=1}^{|T|} i - l_i\),其中\(l_i\)是匹配到第\(i\)个字符时的匹配长度。
然后考虑如何去重。对\(T\)也建SAM,把\(T\)也放在\(T\)的SAM上匹配。发现在匹配到第\(i\)个字符时,以\(i\)为右端点、长度为\([1,l_i]\)的串都是不合法的,而这些不合法的串在\(T\)所在的SAM上对应的是\(parent\)树上的一条链。于是在放在\(T\)的SAM上匹配的时候不断跳父亲,直到\(Shortest_u \leq l_i\),然后在\(u\)上打上\(l_i\)的标记。最后在\(parent\)树上递推一遍把标记传一下就可以计算出不合法的串的数量。
然后考虑一般情况。唯一存在的问题是:可能存在某些情况下到达的状态在\(S[l,r]\)中没有出现。这个时候考虑对于\(S\)上的每一个节点维护它的\(endpos\)集合,此时\(S[l,r]\)中不存在当前对应的串\(\Leftrightarrow\)当前状态的\(endpos\)集合与\([l+len-1,r]\)无交,然后\(--len\),如果\(len =\)父亲状态的\(Longest\)就跳父亲。这里暴力修改\(len\)的总次数仍然是\(O(|T|)\)的所以可以接受。
对于\(endpos\)集合的维护,不难知道某一个点有的\(endpos\)它的祖先也会有。所以就变成了一个子树内某区间是否有值的问题,在\(parent\)树上跑出dfn序然后建立主席树即可。总复杂度\(O(\sum |T| log |S|)\)
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<queue>
#include<vector>
//This code is written by Itst
using namespace std;
const int MAXN = 1e6 + 7;
namespace Tree{
struct node{
int l , r , sz;
}Tree[MAXN << 4];
int rt[MAXN] , cnt;
#define mid ((l + r) >> 1)
int insert(int x , int l , int r , int tar){
int t = ++cnt;
Tree[t] = Tree[x];
++Tree[t].sz;
if(l == r) return t;
if(mid >= tar)
Tree[t].l = insert(Tree[t].l , l , mid , tar);
else
Tree[t].r = insert(Tree[t].r , mid + 1 , r , tar);
return t;
}
bool query(int x , int y , int l , int r , int L , int R){
if(L > R || Tree[x].sz == Tree[y].sz) return 0;
if(l >= L && r <= R) return 1;
if(mid >= L && query(Tree[x].l , Tree[y].l , l , mid , L , R))
return 1;
return mid < R && query(Tree[x].r , Tree[y].r , mid + 1 , r , L , R);
}
}
using Tree::insert; using Tree::query; using Tree::rt;
struct SAM{
int Lst[MAXN] , Sst[MAXN] , fa[MAXN] , trans[MAXN][26] , endpos[MAXN];
int cnt = 1 , lst = 1 , L;
void insert(int len , int x){
int t = ++cnt , p = lst;
Lst[lst = t] = endpos[t] = len;
while(p && !trans[p][x]){
trans[p][x] = t;
p = fa[p];
}
if(!p) {Sst[t] = fa[t] = 1; return;}
int q = trans[p][x];
Sst[t] = Lst[p] + 2;
if(Lst[q] == Lst[p] + 1) {fa[t] = q; return;}
int k = ++cnt;
memcpy(trans[k] , trans[q] , sizeof(trans[k]));
Lst[k] = Lst[p] + 1; Sst[k] = Sst[q];
Sst[q] = Lst[p] + 2;
fa[k] = fa[q]; fa[q] = fa[t] = k;
while(trans[p][x] == q){
trans[p][x] = k;
p = fa[p];
}
}
}S , T;
char s[MAXN];
int mrk[MAXN] , sz[MAXN] , dfn[MAXN] , ts , LS;
vector < int > ch[MAXN];
void clear(){
memset(T.trans , 0 , sizeof(int) * 26 * (T.cnt + 1));
memset(T.fa , 0 , sizeof(int) * (T.cnt + 1));
memset(T.Lst , 0 , sizeof(int) * (T.cnt + 1));
memset(T.Sst , 0 , sizeof(int) * (T.cnt + 1));
memset(mrk , 0 , sizeof(int) * (T.cnt + 1));
T.lst = T.cnt = 1;
}
void dfs(int x){
dfn[x] = ++ts;
sz[x] = 1;
rt[ts] = rt[ts - 1];
if(S.endpos[x])
rt[ts] = insert(rt[ts] , 1 , LS , S.endpos[x]);
for(int i = 0 ; i < ch[x].size() ; ++i){
dfs(ch[x][i]);
sz[x] += sz[ch[x][i]];
}
}
void init(){
scanf("%s" , s + 1);
LS = strlen(s + 1);
for(int i = 1 ; i <= LS ; ++i)
S.insert(i , s[i] - 'a');
for(int i = 2 ; i <= S.cnt ; ++i)
ch[S.fa[i]].push_back(i);
dfs(1);
}
queue < int > q;
int in[MAXN];
long long ans;
void getans(){
for(int i = 2 ; i <= T.cnt ; ++i)
++in[T.fa[i]];
for(int i = 2 ; i <= T.cnt ; ++i)
if(!in[i]) q.push(i);
while(!q.empty()){
int t = q.front(); q.pop();
if(t == 1) continue;
if(mrk[t]){
ans -= min(mrk[t] , T.Lst[t]) - T.Sst[t] + 1;
mrk[T.fa[t]] = mrk[t];
}
if(!--in[T.fa[t]])
q.push(T.fa[t]);
}
}
int main(){
#ifndef ONLINE_JUDGE
freopen("in" , "r" , stdin);
//freopen("out" , "w" , stdout);
#endif
init();
int Q;
for(scanf("%d" , &Q) ; Q ; --Q){
clear();
int l , r;
scanf("%s %d %d" , s + 1 , &l , &r);
int L = strlen(s + 1) , u = 1 , len = 0 , v = 1;
for(int i = 1 ; i <= L ; ++i)
T.insert(i , s[i] - 'a');
for(int i = 1 ; i <= L ; ++i){
while(u - 1 && !S.trans[u][s[i] - 'a'])
len = S.Lst[u = S.fa[u]];
if(S.trans[u][s[i] - 'a']){
u = S.trans[u][s[i] - 'a'];
++len;
}
while(u != 1 && !query(rt[dfn[u] + sz[u] - 1] , rt[dfn[u] - 1] , 1 , LS , l + len - 1 , r)){
--len;
if(len < S.Sst[u])
u = S.fa[u];
}
v = T.trans[v][s[i] - 'a'];
while(v - 1 && T.Sst[v] > len) v = T.fa[v];
mrk[v] = max(mrk[v] , len);
}
ans = 0;
getans();
for(int i = 2 ; i <= T.cnt ; ++i)
ans += T.Lst[i] - T.Sst[i] + 1;
printf("%lld\n" , ans);
}
return 0;
}