题解 WD与数列

P5161 WD与数列

可以想到原条件是一个差分形式,所以我们对原数组差分。然后发现答案其实就是 \(\sum_{i<j} \min(lcp(i+1,j+1)+1,j-i)\)

这个东西先跑 SA,然后建笛卡尔树。

考虑对于一个区间,其值为 \(x\)。那么相当于是求 \(\sum_{l\in S,r\in T} \min(|sa_{l}-sa_{r}|,x)\)

笛卡尔树的一个性质是:较小的区间之和不超过 \(O(n\log n)\)。所以直接暴力枚举较小区间,假设为 \(l\),那么对于右边相当于是求区间内 \(\le\) 某个数的个数,我们显然可以把询问按照 \(x\) 离线下来,从小到大做。或者直接开主席树,都是 \(O(n\log^2n)\) 的。这题略微卡常,注意实现

  • 好像存在 \(O(n\log n)\) 的做法。
#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
const int maxt=1e7, maxn=3e5+5;

int n,m;
int sa[maxn], rk[maxn], cnt[maxn], tp[maxn], height[maxn], lg[maxn], w[maxn][20];
int s[maxn], a[maxn];

void basesort(){
    memset(cnt,0,sizeof(cnt));
    for(int i=1;i<=n;i++) cnt[rk[i]]++;
    for(int i=1;i<=m;i++) cnt[i]+=cnt[i-1]; 
    for(int i=n;i>=1;i--) sa[cnt[rk[tp[i]]]--]=tp[i];
    return ;
}

void SuffixSort() {
    for(int i=1;i<=n;i++) rk[i]=s[i],tp[i]=i;
    basesort();
    for(int w=1,p=0;p<n;m=p,w<<=1) {
        p=0;
        for(int i=1;i<=w;i++) tp[++p]=n-w+i;
        for(int i=1;i<=n;i++) if(sa[i]>w) tp[++p]=sa[i]-w;
        basesort();
        for(int i=1;i<=n;++i) swap(tp[i],rk[i]);
        rk[sa[1]]=1; 
		p=1;
        for(int i=2;i<=n;i++) {
        	if(tp[sa[i-1]]==tp[sa[i]]&&tp[sa[i-1]+w]==tp[sa[i]+w]) rk[sa[i]]=p;
            else rk[sa[i]]=++p;
        }
    }
    return ;
}

int lcp(int l,int r) {
    int k=lg[r-l];
    return min(w[l+1][k],w[r-(1<<k)+1][k]);
}

void init() {
    cin>>n;
    for(int i=1;i<=n;i++) cin>>a[i];
    for(int i=1;i<=n;i++) s[i]=a[i]-a[i-1];
    for(int i=1;i<=n;i++) a[i]=s[i];
    sort(a+1,a+n+1); m=unique(a+1,a+n+1)-a-1;
    for(int i=1;i<=n;i++) s[i]=lower_bound(a+1,a+m+1,s[i])-a;
    SuffixSort();
    int k=0;
    for(int i=1;i<=n;i++) {
        if(k) --k;
        int j=sa[rk[i]-1];
        while(i+k<=n&&s[i+k]==s[j+k]) ++k;
        height[rk[i]]=k;
    }
    lg[1]=0;
    for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
    for(int i=1;i<=n;i++) w[i][0]=height[i];
    for(int j=1;(1<<j)<=n;j++) {
        for(int i=1;i+(1<<j)-1<=n;i++) {
            w[i][j]=min(w[i][j-1],w[i+(1<<j-1)][j-1]);
        }
    }
    return ;
}

int tl[maxn], tr[maxn];

int tot;
int ls[maxt], rs[maxt], rt[maxn];
ll sum[maxt], sum2[maxt];

void build(int &x,int l,int r) {
    x=++tot;
    if(l==r) return ;
    int mid=l+r>>1;
    build(ls[x],l,mid);
    build(rs[x],mid+1,r);
    return ;
}

void updata(int &x,int x2,int p,int c,int l,int r) {
    x=++tot;
    sum[x]=sum[x2]+p*c, sum2[x]=sum2[x2]+c, ls[x]=ls[x2], rs[x]=rs[x2];
    if(l==r) return ;
    int mid=l+r>>1;
    if(p<=mid) updata(ls[x],ls[x2],p,c,l,mid);
    else updata(rs[x],rs[x2],p,c,mid+1,r);
    return ;
}

pair<ll,ll> operator + (pair<ll,ll> x,pair<ll,ll> y) {
    return {x.first+y.first,x.second+y.second};
}

pair<ll,ll> query(int x1,int x2,int L,int R,int l,int r) {
    if(L>R) return {0,0};
    if(L<=l&&r<=R) return {sum[x2]-sum[x1],sum2[x2]-sum2[x1]};
    int mid=l+r>>1;
    pair<ll,ll> res={0,0};
    if(L<=mid) res=res+query(ls[x1],ls[x2],L,R,l,mid);
    if(mid<R) res=res+query(rs[x1],rs[x2],L,R,mid+1,r);
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    init();
    vector<int> vec;
    for(int i=2;i<=n;i++) {
        while(vec.size()&&height[vec.back()]>=height[i]) {
            tr[vec.back()]=i;
            vec.pop_back();
        }
        if(vec.size()) tl[i]=vec.back();
        else tl[i]=1;
        vec.push_back(i);
    }
    while(vec.size()) tr[vec.back()]=n+1, vec.pop_back();
    build(rt[0],1,n);
    for(int i=1;i<=n;i++) updata(rt[i],rt[i-1],sa[i],(sa[i]!=1),1,n);
    ll ans=0;
    for(int i=2;i<=n;i++) {
        int l=tl[i], r=tr[i]; --r;
        int x=height[i]+1;
        if(i-l<=r-i+1) {
            for(int j=l;j<i;j++) {
                int v=sa[j];
                if(v==1) continue;
                ll sum=r-i+1-(i<=rk[1]&&rk[1]<=r);
                pair<ll,ll> res=query(rt[i-1],rt[r],max(v-x,1),v,1,n);
                ans+=res.second*v-res.first, sum-=res.second;
                res=query(rt[i-1],rt[r],v+1,min(v+x,n),1,n);
                ans+=res.first-res.second*v, sum-=res.second;
                ans+=sum*x;
            }
        }else {
            for(int j=i;j<=r;j++) {
                int v=sa[j];
                if(v==1) continue;
                ll sum=i-l-(l<=rk[1]&&rk[1]<i);
                pair<ll,ll> res=query(rt[l-1],rt[i-1],max(v-x,1),v,1,n);
                ans+=res.second*v-res.first, sum-=res.second;
                res=query(rt[l-1],rt[i-1],v+1,min(v+x,n),1,n);
                ans+=res.first-res.second*v, sum-=res.second;
                ans+=sum*x;
            }
        }
    }
    cout<<ans+n-1<<'\n';
    return 0;
}
posted @ 2024-01-19 15:01  OIshima  阅读(8)  评论(0编辑  收藏  举报