LOJ#6031. 「雅礼集训 2017 Day1」字符串 根号分治+SAM+倍增
怎么想都没想出来 $\log n$ 做法,那么这道题基本就是根号分治了.
题目描述中保证 $\sum k \leqslant 10^5$,然后 $k$ 在每次询问中又是相同的,那么就考虑对 $k$ 根号分治.
先对 $s$ 建立后缀自动机,然后把倍增数组求出来.
我们设块的大小为 $B$,那么当 $k \leqslant B$ 时可以对 $k$ 的每一个子串在 $s$ 上都求一遍出现次数 (暴力跳祖先).
其中对于编号为 $[a,b]$ 的限制我们可以直接开一个二维的 vector 存储查询为 $(l,r)$ 的编号,然后 lowerbound 一下就行.
这个的复杂度是 $O(k^2 \log n)$ 的,总复杂度是 $O(k^2 Q \log n)$,即 $O(B n\log n)$.
对于 $k>B$ 时询问次数不会超过 $\frac{10^5}{B}$ 个,那么可以直接对询问按照右端点离线,然后将 $w$ 在 $s$ 上匹配,最后再倍增一下.
这部分的复杂度是 $O(\frac{Q}{k} n \log n)$ 的.
这个 $B$ 取到 400 或 $\sqrt n$ 即可.
code:
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 100009 #define ll long long #define pb push_back #define setIO(s) freopen(s".in","r",stdin) using namespace std; const int B=403; char str[N]; int n,m,Q,k,tot,last,edges; ll cnt[N<<1]; int pre[N<<1],ch[N<<1][27],mx[N<<1]; int hd[N<<1],to[N<<1],nex[N<<1],fa[20][N<<1]; void add(int u,int v) { nex[++edges]=hd[u]; hd[u]=edges,to[edges]=v; } struct oper { int l,r,id; bool operator<(const oper b) const { return r<b.r; } }a[N]; void init() { last=tot=1; } void extend(int c) { int np=++tot,p=last; mx[np]=mx[p]+1,last=np; for(;p&&!ch[p][c];p=pre[p]) { ch[p][c]=np; } if(!p) { pre[np]=1; } else { int q=ch[p][c]; if(mx[q]==mx[p]+1) pre[np]=q; else { int nq=++tot; mx[nq]=mx[p]+1; pre[nq]=pre[q],pre[np]=pre[q]=nq; memcpy(ch[nq],ch[q],sizeof(ch[q])); for(;p&&ch[p][c]==q;p=pre[p]) { ch[p][c]=nq; } } } ++cnt[np]; } void dfs(int x) { fa[0][x]=pre[x]; for(int i=1;i<20;++i) fa[i][x]=fa[i-1][fa[i-1][x]]; for(int i=hd[x];i;i=nex[i]) { dfs(to[i]); cnt[x]+=cnt[to[i]]; } } int len; int trans(int x,int c) { while(x&&!ch[x][c]) x=pre[x],len=mx[x]; if(ch[x][c]) { x=ch[x][c],++len; return x; } else return 1; } namespace sol1 { vector<int>q[B][B]; int main() { for(int i=1;i<=m;++i) { if(a[i].r<=k) { q[a[i].l][a[i].r].pb(i); } } int x,y,z; for(int T=1;T<=Q;++T) { scanf("%s%d%d",str+1,&x,&y); ++x,++y; ll cur=0; z=1,len=0; for(int j=1;j<=k;++j) { z=trans(z,str[j]-'a'); for(int p=z,o=len;o;--o) { // [j-len+1,j] int l=j-o+1; int a1=lower_bound(q[l][j].begin(),q[l][j].end(),x)-q[l][j].begin(); int a2=upper_bound(q[l][j].begin(),q[l][j].end(),y)-q[l][j].begin(); cur+=(a2-a1)*cnt[p]; if(o-1==mx[pre[p]]) p=pre[p]; } } printf("%lld\n",cur); } return 0; } }; int get_up(int x,int kth) { for(int i=19;i>=0;--i) { if(mx[fa[i][x]]>=kth) { x=fa[i][x]; } } return x; } namespace sol2 { int main() { int x,y,z; for(int i=1;i<=m;++i) a[i].id=i; sort(a+1,a+1+m); for(int T=1;T<=Q;++T) { scanf("%s%d%d",str+1,&x,&y); ++x,++y; z=1,len=0; ll cur=0; int lst=1; for(int j=1;j<=k;++j) { z=trans(z,str[j]-'a'); while(a[lst].r<=j&&lst<=m) { if(a[lst].r-a[lst].l+1<=len&&a[lst].id>=x&&a[lst].id<=y) { int p=get_up(z,a[lst].r-a[lst].l+1); cur+=cnt[p]; } ++lst; } } printf("%lld\n",cur); } return 0; } }; int main() { // setIO("input"); // freopen("input.out","w",stdout); scanf("%d%d%d%d%s",&n,&m,&Q,&k,str+1); init(); for(int i=1;i<=n;++i) { extend(str[i]-'a'); } for(int i=2;i<=tot;++i) { add(pre[i],i); } dfs(1); for(int i=1;i<=m;++i) { scanf("%d%d",&a[i].l,&a[i].r); ++a[i].l; ++a[i].r; } if(k<403) sol1::main(); else sol2::main(); return 0; }