【codeforces 666E】 Forensic Examination
http://codeforces.com/problemset/problem/666/E (题目链接)
题意
给出一个主串$S$,$n$个匹配串编号从$1$到$n$。$m$组询问,每次询问主串的一个子串$S[p_l,p_r]$在编号为$[l,r]$的匹配串的哪一个中出现次数最多。
Solution
首先我们把匹配串相连,中间用分隔符隔开,构建后缀自动机,并且给自动机主链上的节点打上标记,避免之后找后缀的时候重复计算。
每个节点与其对应$parent$相连,构建$parent$树。由于$parent$树的性质,那么问题就转化为了对于每一个询问求解在$parent$树上对应$S[p_l,p_r]$的节点子树中,编号为$[l,r]$的节点的众数。那么我们现在就要解决两个问题:第一,如何知道对于每一个询问$S[p_l,p_r]$,它在自动机上匹配以后到达的节点位置;第二,如何求解子树中编号为$[l,r]$的节点的众数。
第一个问题,我们将整个主串$S$在自动机上匹配,那么我们可以得到一个$pos$数组,$pos[i]$表示主串$S$匹配到第$i$位时在自动机上的节点。于是我们就知道了对应子串$S[1,1],S[1,2],S[1,3],S[1,4]······S[1,|S|]$这些子串在自动机上的对应位置。那么对于每一个$S[p_l,p_r]$,我们根据右端点找到$pos[p_r]$,因为长度变短,所以它的实际位置可能是$pos[p_r]$的$parent$树上的祖先,所以我们倍增跳$parent$就可以了。
第二个问题,考虑线段树合并。按照一定的顺序线段树合并,确保随着每个节点的处理完成,位置在这个节点的询问也会跟随着一起处理掉。我们求出$parent$树的$dfs$序,将询问按照其位置的$dfs$序升序排列,然后在树上一边走一边线段树合并,在树上走的时候按照$dfs$序从大往小的顺序走。同时按照从后往前的顺序处理询问即可。
细节
这里线段树的询问返回值用的是pair,感觉应该更好处理一些。出现次数为0记得要特判一下。
代码
// codeforces 666E #include<algorithm> #include<iostream> #include<cstdlib> #include<cstring> #include<cstdio> #include<vector> #include<cmath> #define LL long long #define inf (1ll<<30) #define Pi acos(-1.0) #define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout) using namespace std; const int maxn=500010; int n,m,Q,tot,cnt,rt[maxn],dfn[maxn],fa[maxn][30],id[maxn],Len[maxn],r[maxn]; char s[maxn],st[maxn],ss[maxn]; vector<int> v[maxn]; struct data {int l,r,sl,sr,n,id;}q[maxn],t[maxn]; struct Pair { int x,y; friend bool operator < (Pair a,Pair b) {return a.x==b.x ? a.y>b.y : a.x<b.x;} }ans[maxn]; struct node { int son[2];Pair sum; int& operator [] (int x) {return son[x];} }tr[maxn*40]; namespace SAM { int last; int par[maxn<<1],ch[maxn<<1][27],len[maxn<<1],pos[maxn]; void Extend(int c,int x) { int np=++m,p=last;last=np; len[np]=len[p]+1;id[np]=x;r[np]=1; for (;p && !ch[p][c];p=par[p]) ch[p][c]=np; if (!p) par[np]=1; else { int q=ch[p][c]; if (len[q]==len[p]+1) par[np]=q; else { int nq=++m;len[nq]=len[p]+1;id[nq]=id[q]; memcpy(ch[nq],ch[q],sizeof(ch[q])); par[nq]=par[q]; par[np]=par[q]=nq; for (;p && ch[p][c]==q;p=par[p]) ch[p][c]=nq; } } } void build(char *r) { int len=strlen(r+1),id=1; last=m=1; for (int i=1;i<=len;i++) { Extend(r[i]-'a',id); if (r[i]==123) id++; } for (int i=2;i<=m;i++) v[par[i]].push_back(i); } void match(char *r) { int p=1,e=strlen(r+1),l=0; for (int i=1;i<=e;i++) { while (p>1 && !ch[p][r[i]-'a']) p=par[p],l=len[p]; if (ch[p][r[i]-'a']) pos[i]=p=ch[p][r[i]-'a'],Len[i]=++l; } } void position() { for (int i=1;i<=Q;i++) { int p=pos[q[i].sr]; int tmp=q[i].sr-q[i].sl+1; if (Len[q[i].sr]<tmp) continue; t[++tot]=q[i]; for (int j=20;j>=0;j--) if (len[fa[p][j]]>=tmp) p=fa[p][j]; t[tot].n=p; } } } using namespace SAM; namespace Segtree { int sz; void insert(int &k,int l,int r,int x) { if (!k) k=++sz; if (l==r) {tr[k].sum=(Pair){1,l};return;} int mid=(l+r)>>1; if (x<=mid) insert(tr[k][0],l,mid,x); else insert(tr[k][1],mid+1,r,x); tr[k].sum=max(tr[tr[k][0]].sum,tr[tr[k][1]].sum); } int merge(int x,int y,int l,int r) { if (!x || !y) return x|y; int mid=(l+r)>>1; if (l==r) {tr[x].sum.x+=tr[y].sum.x;return x;} tr[x][0]=merge(tr[x][0],tr[y][0],l,mid); tr[x][1]=merge(tr[x][1],tr[y][1],mid+1,r); tr[x].sum=max(tr[tr[x][0]].sum,tr[tr[x][1]].sum); return x; } Pair query(int k,int s,int t,int l,int r) { if (s==l && r==t) return tr[k].sum; int mid=(l+r)>>1; if (t<=mid) return query(tr[k][0],s,t,l,mid); else if (s>mid) return query(tr[k][1],s,t,mid+1,r); else return max(query(tr[k][0],s,mid,l,mid),query(tr[k][1],mid+1,t,mid+1,r)); } } using namespace Segtree; namespace Basic { bool cmp(data a,data b) {return dfn[a.n]<dfn[b.n];} void dfs(int x) { dfn[x]=++cnt; for (int i=1;i<=20;i++) fa[x][i]=fa[fa[x][i-1]][i-1]; int len=v[x].size(); for (int i=0;i<len;i++) { fa[v[x][i]][0]=x; dfs(v[x][i]); } } void work(int x) { int len=v[x].size(); reverse(v[x].begin(),v[x].end()); for (int i=0;i<len;i++) work(v[x][i]); if (r[x]) insert(rt[x],1,m,id[x]); for (int i=0;i<len;i++) rt[x]=merge(rt[x],rt[v[x][i]],1,m); for (;tot && t[tot].n==x;tot--) ans[t[tot].id]=query(rt[x],t[tot].l,t[tot].r,1,m); } } using namespace Basic; int main() { scanf("%s",s+1); scanf("%d",&n); int len=0; for (int i=1;i<=n;i++) { scanf("%s",ss+1); int l=strlen(ss+1); for (int j=1;j<=l;j++) st[++len]=ss[j]; st[++len]=123; } build(st); match(s); scanf("%d",&Q); for (int i=1;i<=Q;i++) scanf("%d%d%d%d",&q[i].l,&q[i].r,&q[i].sl,&q[i].sr),q[i].id=i; dfs(1); position(); sort(t+1,t+1+tot,cmp); work(1); for (int i=1;i<=Q;i++) if (!ans[q[i].id].x) ans[q[i].id].y=q[i].l; for (int i=1;i<=Q;i++) printf("%d %d\n",ans[i].y,ans[i].x); return 0; }