【CodeChef】Limit of MEX(二分、ST表、组合数学)

题目大意:

计算\(\sum_{L=1}^{N}\sum_{R=L}^{N}f(A_L,...,A_R)\),其中\(f(A_1,A_2,...,A_N)=\max(A_1,A_2,...,A_N)-count(A_1,A_2,...,A_N)+1\)\(count\)函数的值为参数中不同元素的个数。


考虑计算\(\sum_{L=1}^{N}\sum_{R=L}^{N}max(A_1,A_2,...,A_N)\)

对于任意\(1\le i\le n\),我们找出最小的\(l_i\)最大的\(r_i\),使得\(\max(A_{l_i},...,A_{i-1})\le A_i\)\(\max(A_{i+1},...,A_{r_i})<A_i\)。这两个值可以用二分得到。验证二分的值需要用到\(A_1,...,A_N\)连续子区间最大值,这可以用ST表解决。

求出\(l_i\)\(r_i\)后,我们可以得到\(A_i\)对答案的贡献为\(A_i\cdot(i-l_i+1)\cdot(r_i-i+1)\),所以答案为\(\sum_{i=1}^{N}A_i\cdot(i-l_i+1)\cdot(r_i-i+1)\)


考虑计算\(\sum_{L=1}^{N}\sum_{R=L}^{N}count(A_1,A_2,...,A_N)\)

对于任意\(x\in {A_1,A_2,...,A_N}\),我们考虑所有值为\(x\)的元素对答案的贡献,即\(A_1,..A_N\)中有多少个连续子区间包含等于\(x\)元素。

直接计算包含等于\(x\)元素的连续子区间个数较为困难。所以可以算出不包含等于\(x\)元素的连续子区间个数,然后将所有连续子区间个数减去这个值,即可得到包含等于\(x\)元素的连续子区间个数,即当前\(x\)对答案的贡献。

将所有的贡献相加即是答案。

#include<bits/stdc++.h>
#define pt printf(">>>")
#define mid (((l)+(r))/2)
using namespace std;
typedef long long ll;
const ll N=1e6+10,inf=1e18+10,mod=1e9+7;
ll n,a[N],st[N][30],lg[N];
ll query(ll l,ll r){return max(st[l][lg[r-l+1]],st[r-(1<<lg[r-l+1])+1][lg[r-l+1]]);}
ll f(ll x){
	ll ret1=x,ret2=x,l,r;
	l=1,r=x-1;
	while(l<=r)
		if(query(mid,x-1)<=a[x])ret1=mid,r=mid-1;
		else l=mid+1;
	l=x+1,r=n;
	while(l<=r)
		if(query(x+1,mid)<a[x])ret2=mid,l=mid+1;
		else r=mid-1;
	return (x-ret1+1)*(ret2-x+1);
}
ll work1(){
	ll ret=0;
	for(ll i=1;i<=n;i++)st[i][0]=a[i];
	for(ll i=2;i<=n;i++)lg[i]=lg[i>>1]+1;
	for(ll i=1;i<=lg[n];i++)
		for(ll j=1;j+(1<<i)-1<=n;j++)
			st[j][i]=max(st[j][i-1],st[j+(1<<(i-1))][i-1]);
	for(ll i=1;i<=n;i++)ret+=a[i]*f(i);
	return ret;
}
ll cal(ll x){return (x+1)*x/2;}
ll work2(){
	map<ll,vector<ll> > pos;
	ll ret=0;
	for(ll i=1;i<=n;i++)pos[a[i]].push_back(i);
	for(auto c:pos){
		ret+=cal(n);
		for(ll i=0;i<c.second.size();i++){
			ll now=c.second[i];
			if(i==0)ret-=cal(c.second[i]-1);
			else ret-=cal(c.second[i]-c.second[i-1]-1);
			if(i+1==c.second.size())ret-=cal(n-c.second[i]);
		}
	}
	return ret;
}
int main(){
	int T=1;
	cin >> T;
	while(T--){
		cin >> n;
		for(ll i=1;i<=n;i++)cin >> a[i];
		cout << work1()-work2()+(1+n)*n/2 << endl;
	}
	return 0;
}
posted @ 2024-05-23 15:01  Alric  阅读(8)  评论(0编辑  收藏  举报