代码源每日一题Div1 103 子串的最大差 题解

代码源每日一题Div1 103 子串的最大差 题解

题目链接

简要题意

一个序列的最大差,定义为该序列最大值与最小值的差。

给定一个长度为 \(n\) 的数列 \(\{a_n\}\),求出该数列所有连续子串的最大差之和。

\(n\leq 5*10^5,0\leq a_i\leq 10^8\)

本题目原型为CF817D Imbalanced Array

题解

枚举子串肯定不行,双指针也只是算最大/最小而不是求和,所以我们只能算每个数的贡献。

我们记一个子序列的最大值为 \(f(l,r)\),最小值为 \(g(l,r)\),那么我们要求的值即为:

\[\sum\limits_{i=1}^n\sum\limits_{j=i}^n(f(l,r)-g(l,r))=\sum\limits_{i=1}^n\sum\limits_{j=i}^nf(l,r)-\sum\limits_{i=1}^n\sum\limits_{j=i}^ng(l,r) \]

这个等式意味着一个数列的 f 和 g 可以分开来单独计算,而不是合并起来一起计算。

分开计算后,我们分别使用单调栈来分别处理即可:对于某个元素 \(x\),直接找到左右边界,然后乘法原理算出所有区间的数量,最后逐渐累加即可。

同时,注意到这种计算必须得让所有元素各不相同,否则会导致一个区间被算了多次或者出现重复等情况,所以必须得进行一次离散化操作。

离散化操作+单调栈,总复杂度 \(O(n\log n)\)

#include <bits/stdc++.h>
using namespace std;
#define LL long long
const int N = 1000010;
int n, a[N];
int f[N], g[N];
namespace DC {
    struct Node {
        int val, id;
        bool operator < (const Node &rhs) {
            return val < rhs.val;
        }
    } arr[N];
    void solve() {
        for (int i = 1; i <= n; ++i)
            arr[i] = (Node){a[i], i};
        sort(arr + 1, arr + n + 1);
        for (int i = 1; i <= n; ++i)
            a[arr[i].id] = i;
    }
    int get(int id) { return arr[id].val; }
};
LL solve1() {
    memset(f, 0, sizeof(f));
    memset(g, 0, sizeof(g));
    stack<int> s;
    for (int i = 1; i <= n; ++i) {
        while (!s.empty() && a[s.top()] < a[i]) {
            int x = s.top(); s.pop();
            f[x] = i;
        }
        s.push(i);
    }
    for (int i = 1; i <= n; ++i)
        if (f[i] == 0) f[i] = n + 1;
    while (!s.empty()) s.pop();
    for (int i = n; i >= 1; --i) {
        while (!s.empty() && a[s.top()] < a[i]) {
            int x = s.top(); s.pop();
            g[x] = i;
        }
        s.push(i);
    }
    LL res = 0;
    for (int i = 1; i <= n; ++i)
        res += 1LL * DC::get(a[i]) * (i - g[i]) * (f[i] - i);
    return res;
}
LL solve2() {
    memset(f, 0, sizeof(f));
    memset(g, 0, sizeof(g));
    stack<int> s;
    for (int i = 1; i <= n; ++i) {
        while (!s.empty() && a[s.top()] > a[i]) {
            int x = s.top(); s.pop();
            f[x] = i;
        }
        s.push(i);
    }
    for (int i = 1; i <= n; ++i)
        if (f[i] == 0) f[i] = n + 1;
    while (!s.empty()) s.pop();
    for (int i = n; i >= 1; --i) {
        while (!s.empty() && a[s.top()] > a[i]) {
            int x = s.top(); s.pop();
            g[x] = i;
        }
        s.push(i);
    }
    LL res = 0;
    for (int i = 1; i <= n; ++i)
        res += 1LL * DC::get(a[i]) * (i - g[i]) * (f[i] - i);
    return res;
}
int main()
{
    cin >> n;
    for (int i = 1; i <= n; ++i)
        cin >> a[i];
    DC::solve();
    cout << solve1() - solve2() << endl;
    return 0;
}
posted @ 2022-05-11 10:01  cyhforlight  阅读(53)  评论(0编辑  收藏  举报