AtCoder Beginner Contest 234 G - Divide a Sequence
题目描述
把一个长度为 \(N\) 的数组 \(A\), 分为几个连续的子序列 \(B_1, B_2, ... , B_k\),有 \(2^{N-1}\) 种划分方式
先给出数组 \(A\) 求出所有划分方式的价值之和,并对 \(998244353\) 取模.
对于一种划分方式 \(B_1, B_2, ... , B_k\) 的价值为 \({\textstyle \prod_{i=1}^{k}} (\max(B_i) - \min(B_i))\)
对于一个子序列\(B_i = (B_{i,1}, B_{i,1}, ..., B_{i,j} )\) ,其中最大的元素为 \(\max(B_i)\),最小的元素为 \(\min(B_i)\)
思路
首先我们可以想到一个 \(n^2\) 的 \(DP\)
定义 \(f_i\) 为前 \(i\) 个数字的所有划分方式的价值之和
那么可以得到转移方程 \(f_i = \sum_{j=1}^{i-1} f_j \times (\max(a_{j+1},.., a_{i}) - \min(a_{j+1},.., a_{i}))\)
(不包含 \(f_{i-1}\) 的原因是 单个数字的价值为 \(0\))
通过倒序遍历可以维护出 \(\max, \min\) 因此复杂度为 \(O(n^2)\)
这么写肯定是会 \(TLE\) 的
因此我们考虑如何去优化
我们可以把 \(\max\) 和 \(\min\) 的贡献单独去考虑 (这也是常用的一个套路)
首先分析 \(\max\)
转移方程 \(f_i = \sum_{j=1}^{i-1} f_j \times (\max(a_{j+1},.., a_{i}) - \min(a_{j+1},.., a_{i}))\)
可以转换为 \(f_i = \sum_{j=1}^{i-1} f_j \times \max(a_{j+1},.., a_{i}) - \sum_{j=1}^{i-1} f_j \times \min(a_{j+1},.., a_{i})\)
我们可以用 \(m_i\) 代表当前的 \(\max\) 的价值和
那么 \(m_i\) 和 \(m_{i-1}\) 是否存在什么联系呢? 答案是存在的
我们考虑的是最大值对答案的贡献
对于 \(f_i\) 来说,产生贡献的最大值一定是单调上升的一些数字,因为我们 \(f_i\) 进行转移的时候是根据最后一段子序列进行分类的
\(f_i\) 是由 \(f_{i-2}, f_{i-3}, ..., f_{1}, f_{0}\) 转移过来的,我们后缀最大值的贡献一定是在一个连续的区间,并且是单调递增的
那么我们就可以利用一个单调栈,每次弹出栈顶的时候减去栈顶的所有贡献,然后在最后加上当前位置的贡献即可
最小值同理
CODE
/********************
Author: Nanfeng1997
Contest: AtCoder - AtCoder Beginner Contest 234
URL: https://atcoder.jp/contests/abc234/tasks/abc234_g
When: 2022-03-16 10:28:26
Memory: 1024MB
Time: 2000ms
********************/
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
const int MOD = 998244353;
inline int mod(int x) {return x >= MOD ? x - MOD : x;}
inline int ksm(int a, int b) {
int ret = 1; a = mod(a);
for(; b; b >>= 1, a = 1LL * a * a % MOD) if(b & 1) ret = 1LL * ret * a % MOD;
return ret;
}
template<int MOD>
struct modint {
int x;
modint() {x = 0; }
modint(int y) {x = y;}
inline modint inv() const { return modint{ksm(x, MOD - 2)}; }
explicit inline operator int() { return x; }
friend inline modint operator + (const modint &a, const modint& b) { return modint(mod(a.x + b.x)); }
friend inline modint operator - (const modint &a, const modint& b) { return modint(mod(a.x - b.x + MOD)); }
friend inline modint operator * (const modint &a, const modint& b) { return modint(1ll * a.x * b.x % MOD); }
friend inline modint operator - (const modint &a) { return modint(mod(MOD - a.x)); }
friend inline modint& operator += (modint &a, const modint& b) { return a = a + b; }
friend inline modint& operator -= (modint &a, const modint& b) { return a = a - b; }
friend inline modint& operator *= (modint &a, const modint& b) { return a = a * b; }
inline int operator == (const modint &b) { return x == b.x; }
inline int operator != (const modint &b) { return x != b.x; }
inline int operator < (const modint &a) { return x < a.x; }
inline int operator <= (const modint &a) { return x <= a.x; }
inline int operator > (const modint &a) { return x > a.x; }
inline int operator >= (const modint &a) { return x >= a.x; }
};
typedef modint<MOD> mint;
inline mint ksm(mint a, int b) {
mint ret = 1;
for(; b; b >>= 1, a = a * a ) if(b & 1) ret = ret * a ;
return ret;
}
const int N = 3e5 + 10;
int n;
int a[N], s1[N], s2[N];
mint dp[N], tr[N];
void add(int a, mint k) {
while(a <= n) tr[a] += k, a += a & -a;
}
mint query(int x) {
mint ret = 0; if(x >= 0) ret += 1; //树状数组的边界是1, 因此我们手动加上0处的贡献
while(x > 0) ret += tr[x], x -= x & -x;
return ret;
}
mint ask(int l, int r) {return query(r) - query(l - 1); }
void solve() {
scanf("%d", &n);
for(int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
int t1 = 0, t2 = 0;
mint mx = 0, mi = 0;
//dp[0] = 1, 因为我们进行转移的时候是乘法,乘法的幺元是 1
//mx 是最大值的贡献
//mi 是最小值的贡献
for(int i = 1; i <= n; i ++ ) {
while(t1 && a[s1[t1]] <= a[i]) {
int t = s1[t1 --];
mint ret = ask(s1[t1], t - 1);
mx = mx - ret * a[t];
}
mx = mx + ask(s1[t1], i) * a[i];
s1[++ t1] = i;
while(t2 && a[s2[t2]] >= a[i]) {
int t = s2[t2 --];
mint ret = ask(s2[t2], t - 1);
mi = mi - ret * a[t];
}
mi = mi + ask(s2[t2], i) * a[i];
s2[++ t2] = i;
dp[i] = mx - mi;
add(i, dp[i]);
}
printf("%d", (int)dp[n]);
}
int main() {
// ios::sync_with_stdio(false);
// cin.tie(nullptr);
int T = 1; //cin >> T;
while(T --) solve();
return 0;
}