题解[P5161 WD与数列]
题意:
给定一个序列,求不相交的长度相同子串中,
满足对应位置之差为同一值的子串对数。
提供一种 \(\text{SAM+dsu on tree}\) 的解法。
-
首先,子串对应位置差为定值,可以将原序列差分后,转化为查询不相交也不相邻的子串对数。
-
把字符串扔到 \(\text{parent}\) 树上,设两个 \(\operatorname{lcs}\) 为 \(len\) 的字符串右端点为 \(r_1,r_2\)
对一个 \(r_1\),考虑其对各个位置的 \(r_2\) 的贡献:
-
也就是说一个 \(r_1\) 对 \(r_2\) 的贡献是一个二维前缀和的形式,
-
反过来,对一个 \(r_2\),\(r_1\) 的贡献为 \(\begin{cases}r_2-r_1-1,r_1\in[r_2-len-1,r_2-2]\\len,r_2\in[1,r_2-len-2]\end{cases}\)
接下来用 \(\text{dsu on tree}\) 解决全局相互间的贡献。
主要思想是考虑重儿子与轻儿子内的 \(r_1\) 对轻儿子内的 \(r_2\),轻儿子内的 \(r_1\) 对重儿子内的 \(r_2\) 的影响。
-
重儿子中的 \(r_1\) 可以在 \(\text{dsu on tree}\) 过程中统计,统计时先将全部轻儿子子树的贡献加入。
查某轻儿子内的子树时,先去除其子树内的贡献,之后再加回来。
对每个 \(r_2\),按上式算出需要区间内的 \(r_1\) 的个数与和,可以用树状数组实现。 -
轻链对重儿子子树内影响的点是 \(dfs\) 序上连续的一段区间,可以将这些影响离线下来。
按 \(dfs\) 序枚举点以及顺带的修改,把能影响 \(r_2\) 的 \(r_1\) 统计出二维前缀和的贡献,依然能用树状数组。
总时间复杂度为 \(O(n\log^2n)\),但跑得十分快,\(\text{Subtask 3}\) 中仅有一个点跑上 \(1s\)。
附带:对这种形式的 \(\text{dsu on tree}\) 的更多扩展应用。
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=5e5+10,M=5e6+10;
int n,m,x,xx,y,tmp,last=1,tot=1;
int len_lca,tt,ldfn,rdfn,qtot;
ll ans,res;char ch;bool rf;
unordered_map<int,int>trans[N];
int mxlen[N],link[N],pos[N],rev[N],a[N];
inline void read(int &x){
x=rf=0;ch=getchar();while((ch<48)&&(ch^'-'))ch=getchar();
if(ch=='-')rf=1,ch=getchar();
while(ch>47)x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(rf)x=-x;
}
void write(ll x){if(x>9)write(x/10);putchar(48+x%10);}
inline void extend(int a){
tmp=++tot;y=last;mxlen[tmp]=mxlen[y]+1;
for(;y&&!trans[y].count(a);y=link[y])trans[y][a]=tmp;
if(!y)link[tmp]=1;
else {
x=trans[y][a];
if(mxlen[x]==mxlen[y]+1)link[tmp]=x;
else {
xx=++tot;link[xx]=link[x];
mxlen[xx]=mxlen[y]+1;trans[xx]=trans[x];
for(;trans[y].count(a)&&trans[y][a]==x;y=link[y])trans[y][a]=xx;
link[tmp]=link[x]=xx;
}
}
last=tmp;
}
int nextn[N],to[N],h[N],edg;
int nextnq[M],toq[M],hq[N],edgq;
struct mdf{
int l,r,t,v;mdf()=default;
mdf(int _l,int _r,int _t,int _v):l(_l),r(_r),t(_t),v(_v){}
}q[M];
int son[N],sz[N],dfn[N],r_dfn[N];ll w[N],tmpp,tmp_;
inline void add(int x,int y){to[++edg]=y,nextn[edg]=h[x],h[x]=edg;}
inline void addq(int x,int y){toq[++edgq]=y,nextnq[edgq]=hq[x],hq[x]=edgq;}
inline void init(int x){
int i,y;sz[x]=1;r_dfn[dfn[x]=++tt]=x;
for(i=h[x];y=to[i];i=nextn[i]){
init(y);sz[x]+=sz[y];
if(sz[y]>sz[son[x]])son[x]=y;
}
}
int *t[N],t__[N<<1],*t_,_t,_t_,num[N];
#define lowbit(i) i&(-i)
struct bit{
ll t[N],res;int ti;inline void clear(){for(ti=1;ti<=n;++ti)t[ti]=0;}
inline void update(int pos,int v){for(ti=pos;ti<=n;ti+=lowbit(ti))t[ti]+=v;}
inline void inquiry(int pos){res=0;for(ti=pos;ti;ti-=lowbit(ti))res+=t[ti];}
}T,T_;
inline void T_add(int pos){
q[++qtot]=mdf(pos+2,pos+len_lca+2,ldfn,1);
q[++qtot]=mdf(pos+2,pos+len_lca+2,rdfn,-1);
}
inline void T_upd(int pos,int v){T.update(pos,v);T_.update(pos,pos*v);}
inline void T_inq(int pos){T.inquiry(pos);T_.inquiry(pos);res=T.res*(pos+1)-T_.res;}//二维前缀和
inline void update(int pos,int v){T.update(pos,v);T_.update(pos,pos*v);}
inline void inquiry(int pos){
T.inquiry(pos-2);tmpp=T.res;T.inquiry(max(pos-len_lca-2,0));tmpp-=T.res;
T_.inquiry(pos-2);tmp_=T_.res;T_.inquiry(max(pos-len_lca-2,1));tmp_-=T_.res;
res=tmpp*(pos-1)-tmp_;
T.inquiry(max(pos-len_lca-2,0));res+=T.res*len_lca;
}//对轻链的影响
void clear(int x){
int i,y;if(pos[x])update(pos[x],-1);
for(i=h[x];y=to[i];i=nextn[i])clear(y);
}
void dfs(int x){
if(pos[x]){
if(ldfn)T_add(pos[x]);
t[_t][++_t_]=x;
}
int i,y;
for(i=h[x];y=to[i];i=nextn[i])dfs(y);
}
void solve(int x){
int i,y;
for(i=h[x];y=to[i];i=nextn[i])if(y^son[x])solve(y),clear(y);
if(y=son[x])solve(y);len_lca=mxlen[x];
ldfn=dfn[y],rdfn=dfn[y]+sz[y];
if(len_lca){
if(pos[x]){
if(ldfn)T_add(pos[x]);
update(pos[x],1);
}
int j;t_=t__;_t=0;
for(i=h[x];y=to[i];i=nextn[i])if(y^son[x]){
_t_=0;t[++_t]=t_;t_+=sz[y]+1;
dfs(y);num[_t]=_t_;
}
for(i=1;i<=_t;++i)for(j=1;j<=num[i];++j)update(pos[t[i][j]],1);
for(i=1;i<=_t;++i){
for(j=1;j<=num[i];++j)update(pos[t[i][j]],-1);
for(j=1;j<=num[i];++j)if(pos[y=t[i][j]]>2)inquiry(pos[y]),w[y]+=res;
for(j=1;j<=num[i];++j)update(pos[t[i][j]],1);
}
}
}//dsu on tree
void work(){
register int i,qi;
for(i=1;i<=n;++i)a[i]=a[i+1]-a[i];--n;
for(i=1;i<=n;++i)extend(a[i]),pos[tmp]=i,rev[i]=tmp;
for(i=2;i<=tot;++i)add(link[i],i);
init(1);solve(1);T.clear();T_.clear();
for(i=1;i<=qtot;++i)if(q[i].t<=tot)addq(q[i].t,i);
for(i=1;i<=tot;++i){
x=r_dfn[i];
for(qi=hq[i];y=toq[qi];qi=nextnq[qi]){
q[y].l=max(q[y].l,1);
T_upd(q[y].l,q[y].v);
T_upd(q[y].r,-q[y].v);
}
if(pos[x])T_inq(pos[x]),w[x]+=res;
}//离线统计
for(i=1;i<=tot;++i)ans+=w[i];
}
main(){
read(n);register int i;
for(i=1;i<=n;++i)read(a[i]);
ans=1ll*n*(n-1)>>1;//每个长度为1的子串间的贡献
work();write(ans);
}