Loj #2479. 「九省联考 2018」制胡窜
Loj #2479. 「九省联考 2018」制胡窜
题目描述
对于一个字符串 \(S\),我们定义 \(|S|\) 表示 \(S\) 的长度。
接着,我们定义 \(S_i\) 表示 \(S\) 中第 \(i\) 个字符,\(S_{L,R}\) 表示由 \(S\) 中从左往右数,第 \(L\) 个字符到第 \(R\) 个字符依次连接形成的字符串。特别的,如果 \(L > R\) ,或者 \(L < [1, |S|]\), 或者 \(R < [1, |S|]\) 我们可以认为 \(S_{L,R}\) 为空串。
给定一个长度为 \(n\) 的仅由数字构成的字符串 \(S\),现在有 \(q\) 次询问,第 \(k\) 次询问会给出 \(S\) 的一个字符串 \(S_{l,r}\) ,请你求出有多少对 \((i, j)\),满足 \(1 \le i < j \le n\),\(i + 1 \lt j\),且 \(S_{l,r}\) 出现在 \(S_{1,i}\) 中或 \(S_{i+1, j−1}\) 中或 \(S_{j,n}\) 中。
输入格式
输入的第一行包含两个整数 \(n, q\)。
第二行包含一个长度为 \(n\) 的仅由数字构成的字符串 \(S\)。
接下来 \(q\) 行,每行两个正整数 \(l\) 和 \(r\),表示此次询问的子串是 \(S_{l,r}\)。
输出格式
对于每个询问,输出一个整数表示合法的数对个数。
数据范围与提示
对于所有测试数据,\(1 \le n \le 10^5\),\(1 \le q \le 3 · 10^5\),\(1 \le l \le r \le n\)。
\(\\\)
感觉这道题细节贼烦人,正式考试的话估计可以刚一整场。
首先建后缀自动机,然后在使用线段树合并维护\(endpos\)集合。
询问的时候就先在\(fail\)树上倍增找到给定字符串出现的节点。然后我们将合法的\((i,j)\)二元组分为以下三种情况:
- \(S_{1,i}\)中出现
- \(S_{1,i}\)中未出现,\(S_{j,n}\)中出现
- \(S_{1,i},S_{j,n}\)中为出现,\(S_{i+1,j-1}\)中出现。
前两种情况很好算,找到位置最靠前以及最靠后的\(endpos\)就行了。
下面来考虑第三种情况。假设最靠前的\(endpos\)是\(L\),最靠后的是\(R\),字符串长度为\(len\)。显然\(i<L,j>R-len+1\)。
我们先考虑一种暴力做法:枚举\(j\in[R-len+2,n]\),然后算对于每个\(j\)有多少个可行的\(i\)。设\(<j\)的最大的\(endpos\)为\(mx\),显然可行的\(i\)只与\(mx\)有关,为\(\min\{L,mx-len\}\)。
理解了这个暴力做法过后正解就差不多知道了。对于线段树上每个节点,我们令每个位置的权值为其左边第一个\(endpos\)(如果没有则为\(0\)),\(sum\)为这些位置的权值和,\(rmax\)为最右边的\(endpos\),\(lempty\)为左边有多少个位置没有\(endpos\)。注意上述的信息只考虑了线段树所表示的区间,区间外的\(endpos\)不对其产生任何影响。正因为如此,在询问的时候先遍历左儿子,动态更新最右边的\(endpos\),再遍历右儿子计算答案。
道理很简单,就是要注意的边界情况有点多。。。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 200005
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
int n,m;
char s[N];
int fail[N<<1],mxlen[N<<1];
int ch[N<<1][10];
int last=1,cnt=1;
int pos[N<<1],id[N<<1];
ll ss[N];
void Insert(int f,int P) {
int p=last;
int v=++cnt;
pos[v]=P;
id[P]=v;
last=v;
mxlen[v]=mxlen[p]+1;
while(p&&!ch[p][f]) ch[p][f]=v,p=fail[p];
if(!p) return fail[v]=1,void();
int sn=ch[p][f];
if(mxlen[sn]==mxlen[p]+1) return fail[v]=sn,void();
int New=++cnt;
mxlen[New]=mxlen[p]+1;
memcpy(ch[New],ch[sn],sizeof(ch[sn]));
fail[New]=fail[sn];
fail[sn]=fail[v]=New;
while(p&&ch[p][f]==sn) ch[p][f]=New,p=fail[p];
}
int fa[N<<1][20];
vector<int>e[N<<1];
int rt[N<<1];
int ls[N*50],rs[N*50];
int tag[N*50];
int emp[N*50],rmax[N*50];
ll sum[N*50];
int tot;
int lx,rx;
void update(int v,int lx,int rx) {
sum[v]=sum[ls[v]]+sum[rs[v]];
int mid=lx+rx>>1;
ll R=rs[v]?emp[rs[v]]:rx-mid;
sum[v]+=1ll*rmax[ls[v]]*R;
if(!ls[v]||emp[ls[v]]==mid-lx+1) {
emp[v]=mid-lx+1+R;
} else {
emp[v]=emp[ls[v]];
}
if(rs[v]) rmax[v]=rmax[rs[v]];
else rmax[v]=rmax[ls[v]];
}
void Insert(int &v,int lx,int rx,int p) {
v=++tot;
tag[v]=1;
if(lx==rx) {
sum[v]=p;
rmax[v]=lx;
return ;
}
int mid=lx+rx>>1;
if(p<=mid) Insert(ls[v],lx,mid,p);
else Insert(rs[v],mid+1,rx,p);
update(v,lx,rx);
}
int Merge(int a,int b,int lx,int rx) {
if(!a||!b) return a+b;
int v=++tot;
int mid=lx+rx>>1;
ls[v]=Merge(ls[a],ls[b],lx,mid);
rs[v]=Merge(rs[a],rs[b],mid+1,rx);
update(v,lx,rx);
return v;
}
void dfs(int v) {
for(int i=1;i<=18;i++) fa[v][i]=fa[fa[v][i-1]][i-1];
if(pos[v]) Insert(rt[v],lx,rx,pos[v]);
for(int i=0;i<e[v].size();i++) {
int to=e[v][i];
dfs(to);
rt[v]=Merge(rt[v],rt[to],lx,rx);
}
}
int Find(int l,int r) {
int v=id[r];
for(int i=18;i>=0;i--)
if(fa[v][i]&&mxlen[fa[v][i]]>=r-l+1)
v=fa[v][i];
return v;
}
int query_mn(int v,int lx,int rx,int lim) {
if(!v||rx<lim) return 0;
if(lx==rx) return lx;
int mid=lx+rx>>1;
int x=query_mn(ls[v],lx,mid,lim);
if(x) return x;
else return query_mn(rs[v],mid+1,rx,lim);
}
int query_mx(int v,int lx,int rx) {
if(lx==rx) return lx;
int mid=lx+rx>>1;
if(rs[v]) return query_mx(rs[v],mid+1,rx);
else return query_mx(ls[v],lx,mid);
}
ll query_s(int v,int lx,int rx,int l,int r,int &L) {
if(lx>r) return 0;
if(rx<l) {
L=max(L,rmax[v]);
return 0;
}
if(l<=lx&&rx<=r) {
ll x=!v?rx-lx+1:emp[v];
ll ans=sum[v]+1ll*x*L;
L=max(L,rmax[v]);
return ans;
}
int mid=lx+rx>>1;
return query_s(ls[v],lx,mid,l,r,L)+query_s(rs[v],mid+1,rx,l,r,L);
}
ll solve(int v,int len) {
int mn=query_mn(rt[v],lx,rx,1),mx=query_mx(rt[v],lx,rx);
ll ans=0;
if(mn==mx) {
if(mx<n) ans+=ss[n-mx-1];
if(mn-len+1>1) ans+=ss[mn-len-1];
ans+=1ll*(n-mx)*(mn-len);
return ans;
}
if(mn<n) ans+=ss[n-mn-1];
if(mx-len+1>1) ans+=ss[mx-len-1];
if(mx-len+1>mn+1) ans-=ss[mx-len-mn];
int ed=query_mn(rt[v],lx,rx,mn+len-1);
if(ed) {
ed=max(ed,mx-len+1);
ans+=1ll*(n-ed)*(mn-1);
ed--;
} else ed=n-1;
int st=max(mn,mx-len+1);
if(ed>=st) {
int L=0;
ans+=query_s(rt[v],lx,rx,st,ed,L);
ans-=1ll*len*(ed-st+1);
}
return ans;
}
int main() {
n=Get(),m=Get();
for(int i=1;i<=n;i++) ss[i]=ss[i-1]+i;
lx=1,rx=n;
scanf("%s",s+1);
for(int i=1;i<=n;i++) Insert(s[i]-'0',i);
for(int i=2;i<=cnt;i++) {
e[fail[i]].push_back(i);
fa[i][0]=fail[i];
}
dfs(1);
int l,r;
while(m--) {
l=Get(),r=Get();
cout<<solve(Find(l,r),r-l+1)<<"\n";
}
return 0;
}