CodeForces-1691D Max GEQ Sum
Max GEQ Sum
单调栈 + 线段树
这题非常值得
思路:
假设我们此时的最大值为 \(a_i\),那么我们考虑不等式是否成立时,就应该考虑一个最大的区间 \([l,r]\),在这里找有没有不符合不等式的情况出现
最大区间 \([l, r]\):\(a_{l - 1} > a_i\) 且 \(a_{r + 1} > a_i\)
如何找到不符合不等式的情况:
\(sum(l,r) > a_i\)
-> \(sum(l, i-1) + sum(i+1, r) > 0\)
因为我们只要在 \([l, r]\) 上找到一个任意的包含 \(a_i\) 的区间,使得其和大于 \(a_i\) 即可,因此我们只需要考虑其两边任意一个区间大于 0 即可
-> \(sum(l,i-1) > 0\) 或 \(sum(i+1,r) > 0\)
实现:
- 找到区间 \([l, r]\),即找到每个 \(a_i\) 左右两边第一个大于 \(a_i\) 的数
这里考虑使用单调栈,栈底到栈顶,从大到小维护,碰到比当前数字小的就弹出
- 找不等式不成立的情况
这里我们改成 \(sum(l,i) > a_i\) 或 \(sum(i,r) > a_i\) 来实现
我们考虑区间和的时候使用前缀和 \(S\) 维护,以 \(sum(l,i) > a_i\) 为例:
\(sum(j,i) = S_i - S_{j-1}\) \((l \le j \le i)\)
不难发现,在查找的区间中,\(S_i\) 是不变的,因此我们可以转化成为在 \([l-1, i-1]\) 中,找一个 \(S\) 的最小值
这个查找可以考虑用线段树来实现,同理对于右边的情况可以用查找最大值来实现
单调栈:\(O(n)\)
总体线段树查询:\(O(nlogn)\)
时间复杂度:\(O(nlogn)\)
挺可惜没想到线段树 + 前缀和维护,想半天一直卡在找的上面
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <string>
#include <queue>
#include <functional>
#include <map>
#include <set>
#include <cmath>
#include <cstring>
#include <deque>
#include <stack>
using namespace std;
typedef long long ll;
#define pii pair<int, int>
const ll maxn = 2e5 + 10;
const ll inf = 1e17 + 10;
ll num[maxn], sum[maxn];
int pre_l[maxn], pre_r[maxn];
void st_solve(int n)
{
stack<int>st;
st.push(0);
num[0] = num[n+1] = inf;
for(int i=1; i<=n; i++)
{
while(num[st.top()] <= num[i]) st.pop();
pre_l[i] = st.top();
st.push(i);
}
while(st.size()) st.pop();
st.push(n+1);
for(int i=n; i>=1; i--)
{
while(num[st.top()] <= num[i]) st.pop();
pre_r[i] = st.top();
st.push(i);
}
num[0] = num[n+1] = 0;
}
struct node
{
ll maxx, minn;
}tr[maxn << 2];
void push_up(int now)
{
tr[now].maxx = max(tr[now << 1].maxx, tr[now << 1 | 1].maxx);
tr[now].minn = min(tr[now << 1].minn, tr[now << 1 | 1].minn);
}
void build(int now, int l, int r)
{
if(l == r)
{
tr[now].maxx = tr[now].minn = sum[l];
return;
}
int mid = l + r >> 1;
build(now << 1, l, mid);
build(now << 1 | 1, mid + 1, r);
push_up(now);
}
ll query_min(int now, int l, int r, int L, int R)
{
if(L <= l && r <= R)
return tr[now].minn;
int mid = l + r >> 1;
ll ans = inf;
if(L <= mid) ans = query_min(now << 1, l, mid, L, R);
if(R > mid) ans = min(ans, query_min(now << 1 | 1, mid + 1, r, L, R));
return ans;
}
ll query_max(int now, int l, int r, int L, int R)
{
if(L <= l && r <= R)
return tr[now].maxx;
int mid = l + r >> 1;
ll ans = -inf;
if(L <= mid) ans = query_max(now << 1, l, mid, L, R);
if(R > mid) ans = max(ans, query_max(now << 1 | 1, mid + 1, r, L, R));
return ans;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int t;
cin >> t;
while(t--)
{
int n;
cin >> n;
for(int i=1; i<=n; i++)
{
cin >> num[i];
sum[i] = sum[i-1] + num[i];
}
st_solve(n);
build(1, 0, n);
int f = 1;
for(int i=1; i<=n && f; i++)
{
int l = pre_l[i];
int r = pre_r[i];
ll l_val = query_min(1, 0, n, l, i - 1);
ll r_val = query_max(1, 0, n, i, r - 1);
if(max(sum[i] - l_val, r_val - sum[i-1]) > num[i]) f = 0;
}
if(f) cout << "YES" << endl;
else cout << "NO" << endl;
}
return 0;
}