dtoj4542. 「TJOI / HEOI2016」字符串
4542. 「TJOI / HEOI2016」字符串
佳媛姐姐过生日的时候,她的小伙伴从某东上买了一个生日礼物。生日礼物放在一个神奇的箱子中。箱子外边写了一个长为 $ n $ 的字符串 $ s $,和 $ m $ 个问题。佳媛姐姐必须正确回答这 $ m $ 个问题,才能打开箱子拿到礼物,升职加薪,出任 CEO,嫁给高富帅,走上人生巅峰。每个问题均有 $a, b, c, d$ 四个参数,问你子串
$s[a \ldots b]$ 的所有子串和 $s[c \ldots d]$ 的最长公共前缀的长度的最大值是多少?佳媛姐姐并不擅长做这样的问题,所以她向你求助,你该如何帮助她呢?
Sol
对S串建SAM
答案有单调性,考虑二分答案
那么问题转化为一个判断一个串是否在[a,b]之间出现过。
那么相当于判断sam一个点的right集合是否在某范围内。
那么用线段树合并维护right即可
#include<cstdio> #include<iostream> #include<cstdlib> #include<cstring> #include<algorithm> #include<cmath> #define maxn 1000005 #define mid ((l+r)>>1) using namespace std; int n,q,R,cnt,la,rt[maxn],tax[maxn],ord[maxn],tot; int dy[maxn]; char ch[maxn]; struct sam{ int par,Max,pl,nex[26]; }s[maxn]; struct node{ int ls,rs,v; }tr[maxn*30]; void ins(int c){ int np=++cnt,p=la;la=np;s[np].Max=s[p].Max+1; s[np].pl=s[np].Max;dy[s[np].Max]=np; for(;p&&!s[p].nex[c];p=s[p].par)s[p].nex[c]=np; if(!p)s[np].par=R; else { int q=s[p].nex[c],nq; if(s[q].Max==s[p].Max+1)s[np].par=q; else { nq=++cnt,s[nq].Max=s[p].Max+1; for(int j=0;j<26;j++)s[nq].nex[j]=s[q].nex[j]; s[nq].par=s[q].par;s[q].par=s[np].par=nq; for(;p&&s[p].nex[c]==q;p=s[p].par)s[p].nex[c]=nq; } } } void Sort(){ for(int i=1;i<=cnt;i++)tax[s[i].Max]++; for(int i=1;i<=n;i++)tax[i]+=tax[i-1]; for(int i=1;i<=cnt;i++)ord[tax[s[i].Max]--]=i; } void wh(int k){ tr[k].v=tr[tr[k].ls].v+tr[tr[k].rs].v; } void add(int &k,int l,int r,int pl){ if(!k)k=++tot; if(l==r){tr[k].v++;return;} if(pl<=mid)add(tr[k].ls,l,mid,pl); else add(tr[k].rs,mid+1,r,pl); wh(k); } int merge(int x,int y){ if(!x||!y)return x+y; int nx=++tot; tr[nx].ls=merge(tr[x].ls,tr[y].ls); tr[nx].rs=merge(tr[x].rs,tr[y].rs); tr[nx].v=tr[x].v+tr[y].v; return nx; } int ask(int k,int l,int r,int li,int ri){ if(!k)return 0; if(l>=li&&r<=ri)return tr[k].v; int Sum=0; if(li<=mid)Sum+=ask(tr[k].ls,l,mid,li,ri); if(ri>mid)Sum+=ask(tr[k].rs,mid+1,r,li,ri); return Sum; } bool check(int a,int b,int c,int d){ int p=dy[d]; while(s[s[p].par].Max>=d-c+1)p=s[p].par; if(ask(rt[p],1,n,a,b)>0)return 1; return 0; } int main(){ cin>>n>>q; scanf("%s",ch+1); R=cnt=la=1; for(int i=1;i<=n;i++)ins(ch[i]-'a'); Sort(); for(int i=cnt;i>1;i--){ int x=ord[i]; if(s[x].pl)add(rt[x],1,n,s[x].pl); } for(int i=cnt;i>1;i--){ int x=ord[i]; rt[s[x].par]=merge(rt[s[x].par],rt[x]); } for(int i=1,a,b,c,d;i<=q;i++){ scanf("%d%d%d%d",&a,&b,&c,&d); int l=1,r=min(d-c+1,b-a+1); while(l<r){ int M=(l+r+1)/2; if(check(a+M-1,b,c,c+M-1))l=M; else r=M-1; } if(check(a+l-1,b,c,c+l-1))printf("%d\n",l); else puts("0"); } return 0; }