D. Imbalanced Array

题目链接

D. Imbalanced Array

求所有连续子区间的最大值之和减所有连续子区间最小值之和

Input

The first line contains one integer \(n (1 ≤ n ≤ 10^6)\) — size of the array \(a\).

The second line contains \(n\) integers \(a_1, a_2... a_n (1 ≤ a_i ≤ 10^6)\) — elements of the array.

Output

Print one integer — the imbalance value of \(a\).

解题思路

单调栈

可以计算每个值作为最小值和最大值时的贡献,以最小值为例:即找左右两边第一个比其大的数,可利用单调栈实现,但由于会有重复计算,即一段区间内出现多个相同的值,这时可选择寻找第一个左边大于或等于和右边大于的数

  • 时间复杂度:\(O(n)\)

代码

// %%%Skyqwq
#include <bits/stdc++.h>
 
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
 
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
 
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
 
template <typename T> void inline read(T &x) {
    int f = 1; x = 0; char s = getchar();
    while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
    while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
    x *= f;
}

template <typename T>
inline void print(T x)
{
    if(x<0)
    {
        putchar('-');
        x=-x;
    }
    if(x>9)
        print(x/10);
    putchar(x%10+'0');
}

const int N=5e5+5;
stack<LL> stk;
LL sl[N],sr[N],bl[N],br[N],n,a[N];
map<pair<int,PLL>,bool> mp1,mp2;
int main()
{
    read(n);
    for(int i=1;i<=n;i++)read(a[i]);
    for(int i=1;i<=n;i++)
    {
    	while(stk.size()&&a[stk.top()]>a[i])stk.pop();
    	if(stk.size())sl[i]=stk.top()+1;
    	else
    		sl[i]=1;
    	stk.push(i);
    }
    while(stk.size())stk.pop();
    for(int i=1;i<=n;i++)
    {
    	while(stk.size()&&a[stk.top()]<a[i])stk.pop();
    	if(stk.size())bl[i]=stk.top()+1;
    	else
    		bl[i]=1;
    	stk.push(i);
    }
    while(stk.size())stk.pop();
    for(int i=n;i;i--)
    {
    	while(stk.size()&&a[stk.top()]>=a[i])stk.pop();
    	if(stk.size())sr[i]=stk.top()-1;
    	else
    		sr[i]=n;
    	stk.push(i);
    }
    while(stk.size())stk.pop();
    for(int i=n;i;i--)
    {
    	while(stk.size()&&a[stk.top()]<=a[i])stk.pop();
    	if(stk.size())br[i]=stk.top()-1;
    	else
    		br[i]=n;
    	stk.push(i);
    }
    LL res=0;
    for(int i=1;i<=n;i++)
    	res+=a[i]*((i-bl[i]+1)*(br[i]-i+1)-(i-sl[i]+1)*(sr[i]-i+1));
    print(res);
    return 0;
}
posted @ 2022-02-27 13:53  zyy2001  阅读(66)  评论(0编辑  收藏  举报