[八省联考2018]制胡窜 (SAM+大讨论)
正着做着实不太好做,正难则反,考虑反着做。
把i,j看成在切割字符串,我们统计有多少对(i,j)会切割所有与\(s_{l,r}\)相同的串。对于在后缀自动机上表示\(s_{l,r}\)的节点x,x的parent子树内的endpos节点集合,就是和\(s_{l,r}\)相等的串的最后一个字符的出现位置。我们相当于在s串里得到了若干个线段,每个线段表示的子串都和\(s_{l,r}\)相等,然后用两刀把这些串都割了。我们分最左边的串和最右边的串是否存在交集进行讨论。
如果存在交集,线段数量是m
1.第一刀切串[1,i],第二刀切[i+1,m],方案数\((r_{i+1}-r_{i})(r_{i+1}-l_{m})\)
2.第一刀切[1,m],第二刀在第一刀右面随便切,是一个等差数列
3.第一刀切在第一个串左边,第二刀切在交集,一个乘法原理
如果不存在交集
可行的位置收到了限制,我们要求第一刀必须切第一个串,第二刀必须切第m个串,我们讨论出第一刀可行的线段编号区间[L,R],再统计方案数。
总之两种情况都需要维护\(\sum_{i=L}^{R}(r_{i+1}-r_{i})(r_{i+1}-l_{m})\)这个式子,把它拆开。
\[\sum_{i=L}^{R}(r_{i+1}-r_{i})(r_{i+1}-l_{m})
\\=\sum_{i=L}^{R}(\ (r_{i+1}^{2}-r_{i}r_{i+1})-l_{m}(r_{i+1}-r_{i})\ )
\\=\sum_{i=L}^{R}(r_{i+1}^{2}-r_{i}r_{i+1})-l_{m}(r_{R}-r_{L})
\]
常用套路,用线段树合并维护endpos集合,和式第二项维护相邻两项的乘积,对应pushup时左区间max和右区间min,我们需要维护一段区间内最大/最小值,再维护和式即可
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long
using namespace std;
template <typename _T> void read(_T &ret)
{
ret=0; _T fh=1; char c=getchar();
while(c<'0'||c>'9'){ if(c=='-') fh=-1; c=getchar(); }
while(c>='0'&&c<='9'){ ret=ret*10+c-'0'; c=getchar(); }
ret=ret*fh;
}
const int N1=1e5+5, S1=N1*2, M1=S1*70, inf=0x3f3f3f3f;
struct EDGE{
int to[S1],nxt[S1],head[S1],cte;
void ae(int u,int v)
{ cte++; to[cte]=v, nxt[cte]=head[u], head[u]=cte; }
}e;
struct node{ ll sum; int mi,ma;
friend node operator + (const node &s1,const node &s2)
{ return (node){s1.sum+s2.sum-((s2.mi!=inf)?1ll*s1.ma*s2.mi:0ll), min(s1.mi,s2.mi) , max(s1.ma,s2.ma)}; }
};
int n,Q;
char str[N1];
int idx(char c){ return c-'0'; }
struct SEG{
int mi[M1],ma[M1],ls[M1],rs[M1],root[S1],tot; ll sum[M1];
void init(){ mi[0]=inf; }
void pushup(int rt)
{
mi[rt]=min(mi[ls[rt]],mi[rs[rt]]);
ma[rt]=max(ma[ls[rt]],ma[rs[rt]]);
sum[rt]=sum[ls[rt]]+sum[rs[rt]];
if(mi[rs[rt]]!=inf) sum[rt]-=1ll*ma[ls[rt]]*mi[rs[rt]];
}
void ins(int x,int l,int r,int &rt)
{
if(!rt) rt=++tot;
if(l==r){ mi[rt]=ma[rt]=l; sum[rt]=1ll*l*l; return; }
int mid=(l+r)>>1;
if(x<=mid) ins(x,l,mid,ls[rt]);
else ins(x,mid+1,r,rs[rt]);
pushup(rt);
}
//位置互不相同 在线段树叶节点一定会return 无需额外特判
int merge(int r1,int r2)
{
if(!r1||!r2) return r1+r2;
int rt=++tot;
ls[rt]=merge(ls[r1],ls[r2]);
rs[rt]=merge(rs[r1],rs[r2]);
pushup(rt);
return rt;
}
int lower(int x,int l,int r,int rt)
{
if(l==r){
if(mi[rt]<=x) return mi[rt];
else return -1;
}
int mid=(l+r)>>1;
if(mi[rs[rt]]<=x) return lower(x,mid+1,r,rs[rt]);
else return lower(x,l,mid,ls[rt]);
}
int upper(int x,int l,int r,int rt)
{
if(l==r){
if(ma[rt]>=x) return ma[rt];
else return -1;
}
int mid=(l+r)>>1;
if(ma[ls[rt]]>=x) return upper(x,l,mid,ls[rt]);
else return upper(x,mid+1,r,rs[rt]);
}
node query(int L,int R,int l,int r,int rt)
{
if(L<=l&&r<=R){
return (node){sum[rt],mi[rt],ma[rt]};
}
int mid=(l+r)>>1; node ans=(node){0ll,inf,0};
if(L<=mid) ans=(ans+query(L,R,l,mid,ls[rt]));
if(R>mid) ans=(ans+query(L,R,mid+1,r,rs[rt]));
return ans;
}
}s;
int trs[S1][10],pre[S1],dep[S1],id[S1],tot,la;
void init(){ tot=la=1; }
void insert(int c,int i)
{
int p=la,np=++tot,q,nq; la=np;
dep[np]=dep[p]+1;
s.ins(i,1,n,s.root[np]); id[i]=np;
for(;p&&!trs[p][c];p=pre[p]) trs[p][c]=np;
if(!p){ pre[np]=1; return; }
q=trs[p][c];
if(dep[q]==dep[p]+1) pre[np]=q;
else{
pre[nq=++tot]=pre[q];
pre[q]=pre[np]=nq;
dep[nq]=dep[p]+1;
memcpy(trs[nq],trs[q],sizeof(trs[nq]));
for(;p&&trs[p][c]==q;p=pre[p]) trs[p][c]=nq;
}
}
int ff[S1][19];
void dfs(int x)
{
for(int j=2;j<=18;j++) ff[x][j]=ff[ ff[x][j-1] ][j-1];
for(int j=e.head[x];j;j=e.nxt[j]){
int v=e.to[j];
dfs(v);
s.root[x]=s.merge(s.root[x],s.root[v]);
}
}
void build()
{
for(int i=2;i<=tot;i++) e.ae(pre[i],i), ff[i][0]=i, ff[i][1]=pre[i];
dfs(1);
}
int main()
{
// freopen("1.in","r",stdin);
read(n); read(Q);
scanf("%s",str+1);
init(); s.init();
for(int i=1;i<=n;i++) insert(idx(str[i]),i);
build();
int l,r,x,len;
for(int q=1;q<=Q;q++){
read(l); read(r); len=r-l+1;
x=id[r];
// for(;dep[pre[x]]<=len;x=pre[x])
for(int j=18;j>=0;j--)
if(dep[ff[x][j]]>=len) x=ff[x][j];
ll ans=1ll*(n-1)*(n-2)/2,tmp=0;
int r1=s.mi[s.root[x]], rm=s.ma[s.root[x]], lm=rm-len+1, l1=r1-len+1;
if(r1>lm){ //s1与sm有交
tmp+=s.sum[s.root[x]]-1ll*r1*r1-1ll*lm*(rm-r1);
tmp+=max(0ll,1ll*(2*n-lm-1-r1)*(r1-lm)/2);
tmp+=max(0ll,1ll*(l1-1)*(r1-lm));
}else{
int L=s.lower(lm,1,n,s.root[x]);
int R=s.lower(r1+len-2,1,n,s.root[x]), lR=R-len+1;
int nxt=s.upper(R+1,1,n,s.root[x]);
if(L!=-1 && r!=-1 && L<=R){
node k=s.query(L,R,1,n,s.root[x]);
tmp+=k.sum-1ll*L*L-1ll*lm*(R-L);
tmp+=1ll*(r1-lR)*(nxt-lm);
}
}
ans-=tmp;
printf("%lld\n",ans);
}
// printf("%llu\n",(sizeof(s)+sizeof(ff)+sizeof(e)+sizeof(trs))/1024/1024);
return 0;
}