luogu P4384 [八省联考2018]制胡窜
先特判\(l=r\)的答案为\(\binom{n-1}{2}\).这里的枚举\(i,j\)等价于把原串分成三个非空段,然后题目要求的是\(s_{l,r}\)至少在一个段中出现,不妨考虑求总方案数\(\binom{n-1}{2}\)减去\(s_{l,r}\)不在任何一段中出现的方案
把原串分成三个非空段等价于选择两个位置\(i<j<n\),在\(i\)和\(i+1\)之间,\(j\)和\(j+1\)之间画竖线(后面称在\(i\)和\(i+1\)之间的竖线为在\(i\)位置的竖线),那么要使得\(s_{l,r}\)不在整段中出现,相当于对于所有\(k\)满足\(s_{k-(r-l),k}=s_{l,r}\),存在一条在\([k-(r-l),k-1]\)之间的竖线.所以把\(s_{l,r}\)结束位置集合\(endpos\)扣出来,现在问题变成有若干个长度为\(len=r-l\)的区间,要选择两个位置,使得所有区间至少包含两个位置中的一个
后面记\(pl,pr\)分别为\(s_{l,r}\)的\(endpos\)最小和最大的位置.对于这个问题,可以分两种情况.第一种是存在一个位置经过所有区间的交,那么只需要知道这个交的大小\(a=(pl-1)-(pr-len)\)就可以简单算出答案了,大概为\((n-2)a-\binom{a}{2}\)
第二种就相当于是先后选择两个位置\(i,j\),使得可以把区间分成两个集合,并且分别包含一个位置,这里会有\(j\in[\max(pl,pr-len),\min j(j\in endpos(s_{l,r}),j\ge pl+ln)]\),前者是因为你不能把\(j\)设为前面在所有区间的交中统计过的位置,后者是你必须包含 最左边的 与最左边区间没有交 的区间
记\(j\)的合法区间左右端点为\(ql,qr\).然后如果从左往右枚举\(j\),可选的\(i\)的数量是不增的,并且可选\(i\)数量相同的\(j\)会构成若干区间废话,具体的,设\(endpos(s_{l,r})\)中在\([ql,qr]\)中有元素\(j_1,j_2...j_m\),那么如果\(j\in[j_k,j_{k+1})\),可选的\(i\)个数都为\(pl+len-j_k\),由于\(pl+len\)是常量,所以利用线段树维护出区间\([j_1,j_{k+1})\)这些位置的\(\sum j_k\)贡献
然后还剩\([ql,j_1)\),\([j_{k+1},qr]\)这两个区间,可以发现找到\([pl,ql)\)中在最大的\(endpos\)的元素\(j_0\)后,两个区间中每个位置可选的\(i\)个数分别为\(pl+len-j_0,pl+len-j_m\).注意这里\(i\)范围会包含前面算过的所有区间交范围,所以还要减去\((qr-ql+1)a\).还有一种情况是\(endpos(s_{l,r})\)没有元素在\([ql,qr]\)中,这时直接给总方案减去\((qr-ql+1)(pl+len-j_0)\),因为这时所有区间交显然为空
具体细节见代码
#include<bits/stdc++.h>
#define LL long long
#define uLL unsigned long long
#define db double
using namespace std;
const int N=2e5+10;
int rd()
{
int x=0,w=1;char ch=0;
while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+(ch^48);ch=getchar();}
return x*w;
}
char cc[N];
int n,q,lz,rt[N];
struct node
{
int l,r;
LL s;
node(){l=N+1,r=0,s=0;}
node operator + (const node &bb) const
{
node an;
an.l=min(l,bb.l),an.r=max(r,bb.r),an.s=s+bb.s;
if(l<=r&&bb.l<=bb.r&&r<bb.l) an.s+=1ll*r*(bb.l-r);
return an;
}
}s[N*50];
namespace smr
{
int ch[N*50][2],tt;
void inst(int o,int x)
{
s[o].l=s[o].r=x;
int l=1,r=n;
while(l<r)
{
int mid=(l+r)>>1;
if(x<=mid) o=ch[o][0]=++tt,r=mid;
else o=ch[o][1]=++tt,l=mid+1;
s[o].l=s[o].r=x;
}
}
int merg(int o1,int o2)
{
if(!o1||!o2) return o1+o2;
int o=++tt;
ch[o][0]=merg(ch[o1][0],ch[o2][0]);
ch[o][1]=merg(ch[o1][1],ch[o2][1]);
s[o]=s[ch[o][0]]+s[ch[o][1]];
return o;
}
node quer(int o,int l,int r,int ll,int rr)
{
if(!o||ll>rr) return s[0];
if(ll<=l&&r<=rr) return s[o];
int mid=(l+r)>>1;
if(rr<=mid) return quer(ch[o][0],l,mid,ll,rr);
if(ll>mid) return quer(ch[o][1],mid+1,r,ll,rr);
return quer(ch[o][0],l,mid,ll,mid)+quer(ch[o][1],mid+1,r,mid+1,rr);
}
}
namespace sam
{
int to[N],nt[N],hd[N],tot=1;
void adde(int x,int y){++tot,to[tot]=y,nt[tot]=hd[x],hd[x]=tot;}
int fa[N],ff[N][18],ps[N],de[N],tn[N][10],len[N],tt=1,la=1;
void extd(int x,int i)
{
int np=++tt,p=la;
len[np]=len[p]+1,ps[i]=np,smr::inst(rt[np]=++(smr::tt),i),la=np;
while(p&&!tn[p][x]) tn[p][x]=np,p=fa[p];
if(!p) fa[np]=1;
else
{
int q=tn[p][x];
if(len[q]==len[p]+1) fa[np]=q;
else
{
int nq=++tt;
fa[nq]=fa[q],len[nq]=len[p]+1;
memcpy(tn[nq],tn[q],sizeof(int)*10),fa[np]=fa[q]=nq;
while(p&&tn[p][x]==q) tn[p][x]=nq,p=fa[p];
}
}
}
void dfs(int x)
{
ff[x][0]=fa[x];
for(int j=1;j<=lz;++j) ff[x][j]=ff[ff[x][j-1]][j-1];
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
de[y]=de[x]+1,dfs(y),rt[x]=smr::merg(rt[x],rt[y]);
}
}
void inii()
{
for(int i=1;i<=n;++i) extd(cc[i]-'0',i);
for(int i=2;i<=tt;++i) adde(fa[i],i);
dfs(1);
}
void wk()
{
int zl=rd(),zr=rd(),ln=zr-zl+1,x=ps[zr];
LL an=1ll*(n-1)*(n-2)/2;
if(zl==zr){printf("%lld\n",an);return;}
for(int j=lz;~j;--j)
if(len[ff[x][j]]>=ln) x=ff[x][j];
node nw=s[rt[x]];
int pl=nw.l,pr=nw.r,y=0;
if(pl==pr){printf("%lld\n",an-(1ll*(n-2)*(ln-1)-1ll*(ln-1)*(ln-2)/2));return;}
if(pr-ln+1<=pl-1)
{
y=(pl-1)-(pr-ln+1)+1;
an-=(1ll*(n-2)*y-1ll*y*(y-1)/2);
}
int ql=max(pl,pr-ln+1),qr=min(pr-1,smr::quer(rt[x],1,n,pl+ln-1,pr).l-1);
if(ql>qr){printf("%lld\n",an);return;}
node nx=smr::quer(rt[x],1,n,ql,qr);
int v1=pl-(smr::quer(rt[x],1,n,pl,ql).r-ln+1);
if(!nx.r){printf("%lld\n",an-1ll*(qr-ql+1)*v1);return;}
an-=1ll*(nx.l-ql)*v1;
an-=1ll*(nx.r-nx.l)*(pl+ln-1)-nx.s;
int v2=max(pl+ln-1-nx.r,0);
an-=1ll*(qr-nx.r+1)*v2;
an+=1ll*(qr-ql+1)*y;
printf("%lld\n",an);
}
}
int main()
{
n=rd(),q=rd(),scanf("%s",cc+1),lz=log2(n+n);
sam::inii();
while(q--) sam::wk();
return 0;
}