CF Round #783 D - Optimal Partition

D - Optimal Partition

线段树 + dp

设 s[i] 为前缀和数组,f[i] 为考虑前 i 个数的最大答案。当枚举到第 i 个数时,有状态转移方程如下

\[1. \;s[i] -s[j]>0,\;即\;s[j]<s[i]\;(j<i)\\f[i]=max(f[j]+i-j)\;即\;f[i]-i=max(f[j]-j)\\2. \;s[i] -s[j]==0,\;即\;s[j==s[i]\;(j<i)\\f[i]=max(f[j])\\\\3. \;s[i] -s[j]<0,\;即\;s[j]>s[i]\;(j<i)\\f[i]=max(f[j]-i+j)\;即\;f[i]+i=max(f[j]+j)\\ \]

可用 s[i] 作为下标(离散化),建立三颗线段树,分别维护 f[i] - i, f[i], f[i] + i 的最大值

对于第一种情况,设 s[i] 离散化后的下标为 idx, 找到 1 ~ idx 中 f[i] - i 的最大值后 + i 即使当前答案

第二、三种同理

注意状态转移的起点是 s[0] = 0, f[0] = 0, 离散化时要考虑 s[0], 并且一开始将 s[0] 的答案 f[0] 插入到线段树中

线段树包含 0 ~ n 共 n + 1 个点,要开 n + 1 的空间

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long ll;
const int N = 5e5 + 10;
const ll INF = 4e18;
int n;
ll a[N], s[N], f[N];
vector<ll> alls;
struct Node
{
	int l, r;
	ll maxn;
}tr[3][N<<2];

void pushup(int id, int u)
{
	tr[id][u].maxn = max(tr[id][u<<1].maxn, tr[id][u<<1|1].maxn);
}

void build(int id, int u, int l, int r)
{
	tr[id][u] = {l, r};
	if (l == r)
	{
		tr[id][u].maxn = -INF;
		return;
	}
	int mid = l + r >> 1;
	build(id, u << 1, l, mid);
	build(id, u << 1 | 1, mid + 1, r);
	pushup(id, u);
}

void modify(int id, int u, int idx, ll k)
{
	Node &root = tr[id][u];
	if (root.l == idx && root.r == idx)
	{
		root.maxn = max(root.maxn, k);
		return;
	}
	int mid = root.l + root.r >> 1;
	if (idx <= mid)
		modify(id, u << 1, idx, k);
	else
		modify(id, u << 1 | 1, idx, k);
	pushup(id, u);
}

ll query(int id, int u, int l, int r)
{
	Node &root = tr[id][u];
	if (root.l >= l && root.r <= r)
		return root.maxn;
	int mid = root.l + root.r >> 1;
	ll v = -INF;
	if (l <= mid)
		v = query(id, u << 1, l, r);
	if (r > mid)
		v = max(v, query(id, u << 1 | 1, l, r));
	return v;
}

int find(ll x)
{
	return lower_bound(alls.begin(), alls.end(), x) - alls.begin() + 1;
}

int main()
{
	ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
	int T;
	cin >> T;
	while(T--)
	{
		cin >> n;
		alls.clear();
		alls.push_back(0);
		for (int i = 1; i <= n; i++)
		{
			cin >> a[i];
			s[i] = s[i-1] + a[i];
			alls.push_back(s[i]);
		}
		
		sort(alls.begin(), alls.end());
		alls.erase(unique(alls.begin(), alls.end()), alls.end());
		
		for (int i = 0; i < 3; i++)
		{
			build(i, 1, 1, n + 1);
			modify(i, 1, find(s[0]), 0);
		}

		for (int i = 1; i <= n; i++)
		{
			int idx = find(s[i]);
			ll t1 = query(0, 1, 1, idx - 1);
			ll t2 = query(1, 1, idx, idx);
			ll t3 = query(2, 1, idx + 1, n + 1);
			f[i] = max({t1 + i, t2, t3 - i});
			modify(0, 1, idx, f[i] - i);
			modify(1, 1, idx, f[i]);
			modify(2, 1, idx, f[i] + i);
		}
		
		cout << f[n] << endl;
	}
	return 0;
}

posted @ 2022-05-12 20:43  hzy0227  阅读(13)  评论(0编辑  收藏  举报