CF1037H Security(SAM+线段树合并)
题目链接
https://codeforces.com/contest/1037/problem/H
题意
给出一个字符串\(S\)
给出\(Q\)个操作,给出\(L,R,T\),求字典树最小的\(S1\),使得\(S1\)为\(S[L..R]\)的子串,且\(S1\)的字典树严格大于\(T\)。输出这个\(S1\),如果无解输出\(-1\)。
思路
字典序尽量小又要严格大于\(T\),则贪心让前面最多的字符和\(T\)相同,在后面再补一个比\(T\)大且最小的字符。
则最优的方案一定是在\(T\)后面补一个尽可能小的字符。如果补不了,就倒着枚举位置,如果当前位置\(i\)能替换为一个比\(Si\)大的字符,找到最小的可替换字符\(c\)换掉它。答案就是\(T(1∼i-1)+'c'\)。
至于怎么判断一个节点是否为\([L,R]\)子串的节点,用线段树合并获取每个节点\(endpos\)的所有元素。假设当前走到的长度为\(i\),\(endpos\)中存在\(pos\)满足\(pos\in[L,R]\)且\(pos-i+1\in[L,R]\)的节点为\([L,R]\)子串的节点。合起来就是\(pos\in[L+i-1,R]\)。
#include<bits/stdc++.h>
using namespace std;
const int maxx = 2*1e5+10;
char s[maxx],t[maxx];
int last=1,tot=1,fa[maxx],ch[maxx][26],len[maxx];
int sum[50*maxx],ls[50*maxx],rs[50*maxx],rt[maxx],cnt;
int head[maxx],to[maxx],ne[maxx],num;
int ans[maxx];
int n;
void add(int x)
{
int pre=last,now=last=++tot;
len[now]=len[pre]+1;
for(;pre&&!ch[pre][x];pre=fa[pre])ch[pre][x]=now;
if(!pre)fa[now]=1;
else
{
int q=ch[pre][x];
if(len[q]==len[pre]+1)fa[now]=q;
else
{
int nows=++tot;
len[nows]=len[pre]+1;
memcpy(ch[nows],ch[q],sizeof(ch[q]));
fa[nows]=fa[q];fa[q]=fa[now]=nows;
for(;pre&&ch[pre][x]==q;pre=fa[pre])ch[pre][x]=nows;
}
}
}
void addm(int u,int v)
{
to[++num]=v,ne[num]=head[u],head[u]=num;
}
void update(int &u,int l,int r,int x)
{
if(!u)u=++cnt;
if(l==r)
{
sum[u]=1;
return;
}
int mid=(l+r)/2;
if(x<=mid)update(ls[u],l,mid,x);
else update(rs[u],mid+1,r,x);
sum[u]=sum[ls[u]]+sum[rs[u]];
}
int query(int u,int l,int r,int p,int q)
{
if(p>q)return 0;
if(!u)return 0;
if(p<=l&&r<=q)return sum[u];
int mid=(l+r)/2;
int ans=0;
if(p<=mid)ans+=query(ls[u],l,mid,p,q);
if(q>mid)ans+=query(rs[u],mid+1,r,p,q);
return ans;
}
int merge(int a,int b,int l,int r)
{
if(!a)return b;
if(!b)return a;
int u=++cnt;
if(l==r)
{
sum[u]=sum[a]|sum[b];
return u;
}
int mid=(l+r)/2;
ls[u]=merge(ls[a],ls[b],l,mid);
rs[u]=merge(rs[a],rs[b],mid+1,r);
sum[u]=sum[ls[u]]+sum[rs[u]];
return u;
}
void dfs(int u)
{
for(int i=head[u];i;i=ne[i])
{
dfs(to[i]);
rt[u]=merge(rt[u],rt[to[i]],1,n);
}
}
int main()
{
scanf("%s",s+1);
n=strlen(s+1);
for(int i=1;i<=n;i++)
{
update(rt[tot+1],1,n,i);
add(s[i]-'a');
}
for(int i=2;i<=tot;i++)addm(fa[i],i);
dfs(1);
int q,l,r;
scanf("%d",&q);
while(q--)
{
scanf("%d%d%s",&l,&r,t+1);
int m=strlen(t+1);
int u=1,i;
for(i=1;i<=m+1;i++)
{
ans[i]=-1;
int st=(i==m+1?0:t[i]-'a'+1);
for(int j=st;j<26;j++)
{
int v=ch[u][j];
if(v&&query(rt[v],1,n,l+i-1,r))
{
ans[i]=j;
break;
}
}
if(i==m+1)break;
int v=ch[u][t[i]-'a'];
if(v&&query(rt[v],1,n,l+i-1,r))u=v;
else break;
}
while(i&&ans[i]==-1)i--;
if(!i)printf("-1\n");
else
{
for(int j=1;j<i;j++)printf("%c",t[j]);
printf("%c\n",ans[i]+'a');
}
}
return 0;
}