Processing math: 100%

NOI2018 你的名字(SAM + 可持久化线段树合并)

题目链接: https://www.luogu.com.cn/problem/P4770

SAM好题.

(I)首先我们考虑l = 1,r = |S|的情况怎么做

我们要求的是本质不同的子串str的数量,满足str是T的子串,且str不是Sl,r的子串

容易用补集转化成T本质不同的子串数减去S和T本质不同子串数

第一个问题很平凡,我们考虑第二个问题

我们对S,T分别建自动机,令T在S上面跑匹配,同时按着S的跑法在T自己上面跑匹配(因为T的每个子串都为SAMT所接受,所以一定能跑)

对于每个前缀我们都可以求出它和S的最长公共后缀l,及在T上的节点,容易发现这个节点以上的长度<=l的都是本质不同的公共子串,因为可能算重所以先打标记然后Treedp统计(这也是为什么要在T上跑的原因,

因为在S上面跑,每次都要遍历S的parent tree时间复杂度不对)

(II)接下来才是难点,如果l,r任意怎么做

显然对于每个子串都建后缀自动机是不可能的,我们思考我们这个后缀自动机到底干了什么呢?

1.判断有没有tran(p,c)的转移边.

2.判断p这个节点的maxlen和minlen

我们可以发现,只要用线段树合并维护出endpos集合,就可以完成区间的上诉两个问题.

1
2
3
4
5
6
7
8
9
10
11
int u = get(sam[p].ch[c],l + len,r);
if(u){
    len++;
    p = sam[p].ch[c];
    x = sam[x].ch[c];
}
else{
    while(len != -1 && !get(sam[p].ch[c],l + len,r)){
        len--;
        if(len == sam[sam[p].fa].len)   p = sam[p].fa;
    }

  其中get(p,l,r)表示p这个节点的endpos集合在[l,r]范围内的最大值

设正在匹配的最长公共子串为s

我们发现我们原本要做的事情是判断s在p这个节点上能不能添上'c'这个字符,即判断 if(sam[p].ch[c] != 0),但是因为有区间限制我们应判断是否存在一个位置x可以接上s+'c',即在[l,r]区间内,是否存在一个endpos(x)满足x - len(s+'c')  + 1>= l,即x >= l + len(s+'c') - 1也即x >= len(s) + l,于是只要判断[l+len,r]区间内是否存在endpos集合的元素即可

注意我们若失配此时不应该直接跳fa,而应该先让len自减,要记住这个后缀自动机只是一个框架,是S1,n而不是Sl,r的SAM.

有人可能会问:怎么暴力while怎么能过?

因为数据水?  其实这个时间复杂度是正确的,我们考虑势能分析法,容易发现每次while,len最多减少1,外面for循环每次最多增加1,所以单次匹配时间复杂度是O(|T|logn)的

有很多细节,看代码吧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
/*NOI2018[你的名字]*/
#include<bits/stdc++.h>
using namespace std;
#define ll long long
int read(){
    char c = getchar();
    int x = 0;
    while(c < '0' || c > '9')     c = getchar();
    while(c >= '0' && c <= '9')       x = x * 10 + c - 48,c = getchar();
    return x;
}
const int N = 2e6 + 10;
struct SegmentTree{
    int lc,rc;
    int mx;
}t[N<<4];/*线段树维护endpos集合*/
int Rt[N],num,n;
void pushup(int p){
    t[p].mx = max(t[t[p].lc].mx,t[t[p].rc].mx);
}
void Insert(int &p,int l,int r,int pos){
    if(!p)  p = ++num;
    if(l == r){
        t[p].mx = max(t[p].mx,pos);
        return;
    }
    int mid = (l + r) >> 1;
    if(pos <= mid)   Insert(t[p].lc,l,mid,pos);
    else    Insert(t[p].rc,mid+1,r,pos);
    pushup(p);
}
int merge(int p,int q,int l,int r){
    if(!p || !q)    return p | q;
    int u = ++num;
    int mid = (l + r) >> 1;
    t[u].lc = merge(t[p].lc,t[q].lc,l,mid);
    t[u].rc = merge(t[p].rc,t[q].rc,mid + 1,r);
    pushup(u);
    return u;
}
int query(int p,int l,int r,int a,int b){
    if(a <= l && b >= r)  return t[p].mx;
    int mid = (l + r) >> 1;
    int ans = 0;
    if(a <= mid) ans = max(ans,query(t[p].lc,l,mid,a,b));
    if(b > mid)      ans = max(ans,query(t[p].rc,mid+1,r,a,b));
    return ans;
}
struct SAM{
    int ch[26],len,fa;
}sam[N<<1];
int lst = 1,cnt = 1;
void ins(int c,int rt){
    int p = lst,np = ++cnt;lst = np;
    sam[np].len = sam[p].len + 1;
    for(; !sam[p].ch[c]; p = sam[p].fa)     sam[p].ch[c] = np;
    if(!p)  sam[np].fa = rt;
    else{
        int q = sam[p].ch[c];
        if(sam[q].len == sam[p].len + 1)    sam[np].fa = q;
        else{
            int nq = ++cnt;
            sam[nq] = sam[q];
            sam[nq].len = sam[p].len + 1;
            sam[np].fa = sam[q].fa = nq;
            for(; sam[p].ch[c] == q; p = sam[p].fa)     sam[p].ch[c] = nq;
        }
    }
}
int head[N<<1];
int f[N<<1],tot;
struct Edge{
    int nxt,point;
}edge[N<<1];
void add_edge(int u,int v){
    edge[++tot].nxt = head[u];
    edge[tot].point = v;
    head[u] = tot;
}
char S[N],T[N];
void dfs(int u){
    for(int i = head[u]; i ; i = edge[i].nxt){
        int v = edge[i].point;
        dfs(v);
        f[u] = max(f[u],f[v]);
    }
    f[u] = min(f[u],sam[u].len);
}
void getpos(int u){
    for(int i = head[u]; i ; i = edge[i].nxt){
        int v = edge[i].point;
        getpos(v);
        Rt[u] = merge(Rt[u],Rt[v],1,n);
    }
}
bool valid(int u,int len){
    return len >= sam[sam[u].fa].len + 1 && len <= sam[u].len;
}
int get(int u,int l,int r){
    if(l > r || !u)  return 0;
    return query(Rt[u],1,n,l,r);
}
int getlen(int u,int l,int r){
    int x = get(u,l,r);
    return min(sam[u].len,x - l + 1);
}
ll work(char *s,int rt,int l,int r){
    int m = strlen(s+1);
    int p = 1,len = 0,x = rt;
    for(int i = rt + 1; i <= cnt; ++i){
        add_edge(sam[i].fa,i);
    }
    for(int i = 1; i <= m; ++i){
        int c = s[i] - 'a';
        int u = get(sam[p].ch[c],l + len,r);
        if(u){
            len++;
            p = sam[p].ch[c];
            x = sam[x].ch[c];
        }
        else{
            while(len != -1 && !get(sam[p].ch[c],l + len,r)){
                len--;
                if(len == sam[sam[p].fa].len)   p = sam[p].fa;
            }
            if(len == -1){
                p = 1;
                len = 0;
                x = rt;
            }
            else{
                len++;
                p = sam[p].ch[c];  
                while((!sam[x].ch[c] || !valid(sam[x].ch[c],len)) && x)     x = sam[x].fa;
                if(!x)  x = rt;
                x = sam[x].ch[c];
            }
        }
//      cout<<i<<' '<<len<<endl;
        f[x] = max(f[x],len);
    }
    dfs(rt);
    ll ans = 0;
    for(int i = rt + 1; i <= cnt; ++i){/*!!!attention*/
        if(f[i] > sam[sam[i].fa].len){
//          assert(f[i] > sam[sam[i].fa].len);
            ans += f[i] - sam[sam[i].fa].len;
        }
    }
    for(int i = rt; i <= cnt; ++i)       f[i] = 0;
    return ans;
}
int main(){
    freopen("name.in","r",stdin);
    freopen("name.out","w",stdout);
    scanf("%s",S+1);
    n = strlen(S+1);
    for(int i = 1; i <= n; ++i){
        ins(S[i]-'a',1);
        Insert(Rt[lst],1,n,i);
    }
    for(int i = 2; i <= cnt; ++i){
        add_edge(sam[i].fa,i);
    }
    getpos(1); 
    int q = read();
    while(q--){
        scanf("%s",T+1);
        int l = read(),r = read();
        int m = strlen(T+1);
        int rt = ++cnt;
        lst = rt;
        for(int i = 1; i <= m; ++i){
            ins(T[i]-'a',rt);
        }
        ll ans = 0;
        for(int i = rt + 1; i <= cnt; ++i){
            ans += sam[i].len - sam[sam[i].fa].len;
        }
        ans -= work(T,rt,l,r);
        printf("%lld\n",ans);
    }
    return 0;
}

  

 

posted @   y_dove  阅读(202)  评论(0编辑  收藏  举报
努力加载评论中...
点击右上角即可分享
微信分享提示