Norma
写在前面:序列分治是一种比较常见的技巧,将问题划分中点,每次只考虑计算跨区间中点算贡献,然后递归左右区间分开统计答案即可。
Analysis
本题不同的区间均能对答案产生贡献,我们考虑如何跨中点算贡献。与常见跨中点算贡献的套路相同,我们将跨中点右区间划分为三段:最值靠左,最值左右各异,最值靠右。再根据 \(cost_{i,j}\) 的定义来看,答案显然与区间长度有关,因此我们分类进行讨论。
最值都在左区间:
每段区间的最值都在左区间,直接乘起来就好。但是右区间的长度同样会造成贡献,怎么办?观察到此时的右区间长度产生贡献是一个等差数列,直接“高斯求和”可解。
最值左右区间各一个:
此时的最值不确定究竟是谁在左(或在右),因此分讨:
-
最大值在左区间:
此时右区间的最小值不确定,因此预处理出右区间最小值前缀和,这样实现了维护最大值乘最小值,但还短一个关键的系数:长度。长度不方便预处理,因为在中点靠左仍有长度会产生贡献,于是我们考虑拆开原式,具体来说:
\[\begin{align*} cost_{i,j}&=\sum maxn_l\times minn_r\times len_{i\sim j}\\ &=maxn_l\times\sum minn_r\times (len_{i\sim mid}+len_{mid+1\sim j})\\ &=maxn_l\times(\sum minn_r\times len_{i\sim mid}+\sum minn_r\times len_{mid+1\sim j})\\ &=maxn_l\times(len_{i\sim mid}\times \sum minn_r+\sum minn_r\times len_{mid+1\sim j})\\ \end{align*} \]我们将 \(\sum minn_r\) 与 \(\sum minn_r\times len_{mid+1\sim j}\) 分别前缀和维护,\(\Theta(1)\) 可求。
-
最大值在右区间:
与第一种情况大同小异,不再赘述。
最值都在右区间:
此时最值全部在右区间,长度这个系数仍然不好维护,我们采取与2中相同的思路:
\[\begin{align*}
cost_{i,j}&=\sum maxn_r\times minn_r\times len_{i\sim j}\\
&=\sum maxn_r\times minn_r\times (len_{i\sim mid}+len_{mid+1\sim j})\\
&=len_{i\sim mid} \times \sum maxn_r\times minn_r+\sum maxn_r\times minn_r\times len_{mid+1\sim j}
\end{align*}\]
此时我们同样维护 \(\sum maxn_r\times minn_r\) 与 \(\sum maxn_r\times minn_r\times len_{mid+1\sim j}\) 的前缀和,\(\Theta(1)\) 求解。
Code
总体来说维护6种不同的系数前缀和,每次递归算跨中点贡献,双指针扫描跨中点区间即可。
点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5e5+10,p=1e9;
int n,ans(0),inv2;
int a[N];
int sum1[N],sum2[N],pre1[N],pre2[N],sum3[N],pre3[N];
void solve(int l,int r){
if(l==r) return ans=(ans+a[l]*a[l]%p)%p,void();
int mid((l+r)>>1);
int maxr=0,minr=1e9;
pre1[mid]=pre2[mid]=pre3[mid]=sum1[mid]=sum2[mid]=sum3[mid]=0;
for(int i=mid+1;i<=r;++i){
maxr=max(maxr,a[i]);
minr=min(minr,a[i]);
sum1[i]=(sum1[i-1]+minr)%p;
sum2[i]=(sum2[i-1]+maxr)%p;
sum3[i]=(sum3[i-1]+minr*maxr%p)%p;
pre1[i]=(pre1[i-1]+minr*(i-mid)%p)%p;
pre2[i]=(pre2[i-1]+maxr*(i-mid)%p)%p;
pre3[i]=(pre3[i-1]+maxr*minr%p*(i-mid)%p)%p;
}
int pos1(mid+1),pos2(mid+1);
int maxnl(0),minnl(1e9);
for(int i=mid;i>=l;--i){
maxnl=max(maxnl,a[i]);
minnl=min(minnl,a[i]);
while(pos1<=r&&a[pos1]<=maxnl) pos1++;pos1--;
while(pos2<=r&&a[pos2]>=minnl) pos2++;pos2--;
int w1(min(pos1,pos2)),w2(max(pos1,pos2));
ans=(ans+maxnl*minnl%p*((w1+mid-2*i+3)*(w1-mid)/2%p)%p)%p;
if(pos1<pos2) ans=(ans+minnl*((pre2[w2]-pre2[w1]+p)%p+(mid-i+1)*(sum2[w2]-sum2[w1]+p)%p)%p)%p;
if(pos2<pos1) ans=(ans+maxnl*((pre1[w2]-pre1[w1]+p)%p+(mid-i+1)*(sum1[w2]-sum1[w1]+p)%p)%p)%p;
ans=(ans+(pre3[r]-pre3[w2]+p)%p+(mid-i+1)*(sum3[r]-sum3[w2]+p)%p)%p;
}
solve(l,mid),solve(mid+1,r);
}
signed main(){
scanf("%lld",&n);
for(int i=1;i<=n;++i) scanf("%lld",a+i);
solve(1,n);
printf("%lld",ans);
}