luoguP5161 WD与数列 后缀自动机+线段树合并+启发式合并
第一次写这个题是好长时间以前了,然后没调出来.
本来以为是思路错了,结果今天看题解发现思路没错,但是好多代码细节需要注意.
code:
#include <cstdio> #include <vector> #include <map> #include <cstring> #include <algorithm> #define N 600008 #define ll long long #define setIO(s) freopen(s".in","r",stdin) using namespace std; int total; int rt[N],arr[N],rk[N],bu[N],id[N]; vector<int>G[N]; namespace seg { int tot; int newnode() { return ++tot; } struct data { int ls,rs,sum1; ll sum2; data() { ls=rs=sum1=sum2=0; } data operator+(const data &b) const { data c; c.sum1=sum1+b.sum1; c.sum2=sum2+b.sum2; return c; } }s[N*20],bl; void update(int &x,int l,int r,int p,int v) { if(!x) x=newnode(); ++s[x].sum1; s[x].sum2+=v; if(l==r) return; int mid=(l+r)>>1; if(p<=mid) update(s[x].ls,l,mid,p,v); else update(s[x].rs,mid+1,r,p,v); } data query(int x,int l,int r,int L,int R) { if(!x||r<L||l>R||L>R) return bl; if(l>=L&&r<=R) return s[x]; int mid=(l+r)>>1; if(L<=mid&&R>mid) return query(s[x].ls,l,mid,L,R)+query(s[x].rs,mid+1,r,L,R); else if(L<=mid) return query(s[x].ls,l,mid,L,R); else return query(s[x].rs,mid+1,r,L,R); } int merge(int x,int y) { if(!x||!y) return x+y; int now=newnode(); s[now].sum1=s[x].sum1+s[y].sum1; s[now].sum2=s[x].sum2+s[y].sum2; s[now].ls=merge(s[x].ls,s[y].ls); s[now].rs=merge(s[x].rs,s[y].rs); return now; } }; namespace sam { int tot,last; map<int,int>ch[N]; int len[N],pre[N]; void init() { tot=last=1; } void extend(int c) { int np=++tot,p=last; len[np]=len[p]+1,last=np; for(;p&&!ch[p][c];p=pre[p]) ch[p][c]=np; if(!p) pre[np]=1; else { int q=ch[p][c]; if(len[q]==len[p]+1) pre[np]=q; else { int nq=++tot; len[nq]=len[p]+1; pre[nq]=pre[q],pre[q]=pre[np]=nq,ch[nq]=ch[q]; for(;p&&ch[p][c]==q;p=pre[p]) ch[p][c]=nq; } } G[np].push_back(len[np]); seg::update(rt[np],1,total,len[np],len[np]); } void get_rank() { int i,j; for(i=1;i<=tot;++i) ++bu[len[i]]; for(i=1;i<=tot;++i) bu[i]+=bu[i-1]; for(i=1;i<=tot;++i) rk[bu[len[i]]--]=i; } }; int main() { // setIO("input"); int i,j,n; sam::init(); scanf("%d",&n); for(i=1;i<=n;++i) scanf("%d",&arr[i]); for(i=1;i<n;++i) arr[i]=arr[i+1]-arr[i]; total=n-1; for(i=1;i<=total;++i) sam::extend(arr[i]); sam::get_rank(); ll ans=0ll; ans=1ll*n*(n-1)/2; for(i=1;i<=sam::tot;++i) id[i]=i; for(i=sam::tot;i>=2;--i) { int u=rk[i]; int a=rk[i]; int b=sam::pre[a]; if(G[id[a]].size()>G[id[b]].size()) swap(a,b); // id[a] < id[b] for(j=0;j<G[id[a]].size();++j) { int x=G[id[a]][j]; ans+=(ll)sam::len[sam::pre[u]]*seg::query(rt[b],1,total,1,x-sam::len[sam::pre[u]]-1).sum1; ans+=(ll)sam::len[sam::pre[u]]*seg::query(rt[b],1,total,x+sam::len[sam::pre[u]]+1,total).sum1; seg::data tmp1=seg::query(rt[b],1,total,x-sam::len[sam::pre[u]],x-2); seg::data tmp2=seg::query(rt[b],1,total,x+2,x+sam::len[sam::pre[u]]); ans+=(ll)(x-1)*tmp1.sum1-tmp1.sum2; ans+=tmp2.sum2-(ll)(x+1)*tmp2.sum1; G[id[b]].push_back(x); } id[sam::pre[u]]=id[b]; rt[sam::pre[u]]=seg::merge(rt[a],rt[b]); } printf("%lld\n",ans); return 0; }