AtCoder Beginner Contest 234 G - Divide a Sequence

传送门

题目描述

把一个长度为 \(N\) 的数组 \(A\), 分为几个连续的子序列 \(B_1, B_2, ... , B_k\),有 \(2^{N-1}\) 种划分方式

先给出数组 \(A\) 求出所有划分方式的价值之和,并对 \(998244353\) 取模.

对于一种划分方式 \(B_1, B_2, ... , B_k\) 的价值为 \({\textstyle \prod_{i=1}^{k}} (\max(B_i) - \min(B_i))\)

对于一个子序列\(B_i = (B_{i,1}, B_{i,1}, ..., B_{i,j} )\) ,其中最大的元素为 \(\max(B_i)\),最小的元素为 \(\min(B_i)\)

思路

首先我们可以想到一个 \(n^2\)\(DP\)
定义 \(f_i\) 为前 \(i\) 个数字的所有划分方式的价值之和
那么可以得到转移方程 \(f_i = \sum_{j=1}^{i-1} f_j \times (\max(a_{j+1},.., a_{i}) - \min(a_{j+1},.., a_{i}))\)
(不包含 \(f_{i-1}\) 的原因是 单个数字的价值为 \(0\))
通过倒序遍历可以维护出 \(\max, \min\) 因此复杂度为 \(O(n^2)\)

这么写肯定是会 \(TLE\)

因此我们考虑如何去优化

我们可以把 \(\max\)\(\min\) 的贡献单独去考虑 (这也是常用的一个套路)

首先分析 \(\max\)

转移方程 \(f_i = \sum_{j=1}^{i-1} f_j \times (\max(a_{j+1},.., a_{i}) - \min(a_{j+1},.., a_{i}))\)
可以转换为 \(f_i = \sum_{j=1}^{i-1} f_j \times \max(a_{j+1},.., a_{i}) - \sum_{j=1}^{i-1} f_j \times \min(a_{j+1},.., a_{i})\)

我们可以用 \(m_i\) 代表当前的 \(\max\) 的价值和
那么 \(m_i\)\(m_{i-1}\) 是否存在什么联系呢? 答案是存在的

我们考虑的是最大值对答案的贡献
对于 \(f_i\) 来说,产生贡献的最大值一定是单调上升的一些数字,因为我们 \(f_i\) 进行转移的时候是根据最后一段子序列进行分类的
\(f_i\) 是由 \(f_{i-2}, f_{i-3}, ..., f_{1}, f_{0}\) 转移过来的,我们后缀最大值的贡献一定是在一个连续的区间,并且是单调递增的
那么我们就可以利用一个单调栈,每次弹出栈顶的时候减去栈顶的所有贡献,然后在最后加上当前位置的贡献即可

最小值同理

CODE
/********************
Author:  Nanfeng1997
Contest: AtCoder - AtCoder Beginner Contest 234
URL:     https://atcoder.jp/contests/abc234/tasks/abc234_g
When:    2022-03-16 10:28:26

Memory:  1024MB
Time:    2000ms
********************/
#include  <bits/stdc++.h>
using namespace std;
using LL = long long;
const int MOD = 998244353;

inline int mod(int x) {return x >= MOD ? x - MOD : x;}

inline int ksm(int a, int b) {
  int ret = 1; a = mod(a);
  for(; b; b >>= 1, a = 1LL * a * a % MOD) if(b & 1) ret = 1LL * ret * a % MOD;
  return ret;
}

template<int MOD> 
struct modint {
  int x;
  modint() {x = 0; }
  modint(int y) {x = y;}
  inline modint inv() const { return modint{ksm(x, MOD - 2)}; }
  explicit inline operator int() { return x; }
  friend inline modint operator + (const modint &a, const modint& b) { return modint(mod(a.x + b.x)); }
  friend inline modint operator - (const modint &a, const modint& b) { return modint(mod(a.x - b.x + MOD)); }
  friend inline modint operator * (const modint &a, const modint& b) { return modint(1ll * a.x * b.x % MOD); }
  friend inline modint operator - (const modint &a) { return modint(mod(MOD - a.x)); }
  friend inline modint& operator += (modint &a, const modint& b) { return a = a + b; }
  friend inline modint& operator -= (modint &a, const modint& b) { return a = a - b; }
  friend inline modint& operator *= (modint &a, const modint& b) { return a = a * b; }
  inline int operator == (const modint &b) { return x == b.x; }
  inline int operator != (const modint &b) { return x != b.x; }
  inline int operator < (const modint &a) { return x < a.x; }
  inline int operator <= (const modint &a) { return x <= a.x; }
  inline int operator > (const modint &a) { return x > a.x; }
  inline int operator >= (const modint &a) { return x >= a.x; }
};

typedef modint<MOD> mint;

inline mint ksm(mint a, int b) {
  mint ret = 1; 
  for(; b; b >>= 1, a = a * a ) if(b & 1) ret = ret * a ;
  return ret;
}

const int N = 3e5 + 10;

int n;
int a[N], s1[N], s2[N];
mint dp[N], tr[N];

void add(int a, mint k) {
  while(a <= n) tr[a] += k, a += a & -a;
}

mint query(int x) {
  mint ret = 0; if(x >= 0) ret += 1; //树状数组的边界是1, 因此我们手动加上0处的贡献
  while(x > 0) ret += tr[x], x -= x & -x;
  return ret;
}

mint ask(int l, int r) {return query(r) - query(l - 1); }

void solve() {
  scanf("%d", &n);
  for(int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
  int t1 = 0, t2 = 0;
  mint mx = 0, mi = 0; 

  //dp[0] = 1, 因为我们进行转移的时候是乘法,乘法的幺元是 1
  //mx 是最大值的贡献 
  //mi 是最小值的贡献
  
  for(int i = 1; i <= n; i ++ ) {

    while(t1 && a[s1[t1]] <= a[i]) {
      int t = s1[t1 --];
      mint ret = ask(s1[t1], t - 1);
      mx = mx - ret * a[t];
    }

    mx = mx + ask(s1[t1], i) * a[i];
    s1[++ t1] = i;
    
    while(t2 && a[s2[t2]] >= a[i]) {
      int t = s2[t2 --];
      mint ret = ask(s2[t2], t - 1);
      mi = mi - ret * a[t];
    }

    mi = mi + ask(s2[t2], i) * a[i];
    s2[++ t2] = i;

    dp[i] = mx - mi;
    add(i, dp[i]);
  }
  printf("%d", (int)dp[n]);

}

int main() {
  // ios::sync_with_stdio(false);
  // cin.tie(nullptr);
  int T = 1; //cin >> T;
  while(T --) solve();

  return 0;
}

posted @ 2022-03-16 19:43  ccz9729  阅读(57)  评论(0编辑  收藏  举报