二叉搜索树「区间DP」

题目描述

\(n\)个结点,第\(i\)个结点的权值为\(i\)

你需要对它们进行一些操作并维护一些信息,因此,你需要对它们建立一棵二叉搜索树。在整个操作过程中,第\(i\)个点需要被操作 \(x_i\) 次,每次你需要从根结点一路走到第 \(i\) 个点,耗时为经过的结点数。最小化你的总耗时。

输入格式

第一行一个整数n,第二行n个整数x1~xn。
输出格式

一行一个整数表示答案。

样例

样例输入

5
8 2 1 4 3

样例输出

35

数据范围与提示

对于10%的数据,\(n \leq 10\)

对于40%的数据,\(n \leq 300\)

对于70%的数据,\(n \leq 2000\)

对于100%的数据,\(n \leq 5000\)\(1 \leq x_i \leq 10^9\)

提示:二叉搜索树或者是一棵空树,或者是具有下列性质的二叉树:若它的左子树不空,则左子树上所有结点的值均小于它的根结点的值;若它的右子树不空,则右子树上所有结点的值均大于它的根结点的值;它的左、右子树也分别为二叉搜索树。

简单解释

不要被题目给吓到,这题并不用BST。

因为区间内的节点的权值是单增且连续的。

所以我们考虑从区间内单独拎某一个节点出来作为根,他左边的点一定都会是他的左子树,右边的一定会是他的右子树。

从这里直接想到区间DP可能有点难,所以我们可以先从记忆化DFS的方向来考虑。

我们枚举根 \(k\),把整个区间分成左右两半,相当于是把 \(k\) 的左边和右边的节点的深度都增加了1,对于答案来说,答案会增加 \(sum[k - 1] - sum[l - 1] + sum[r] - sum[k + 1] + x[k]\),最后的 \(x[k]\) 是根自己。(\(sum\)\(x_i\)前缀和)

化简一下就是 \(sum[r] - sum[l - 1]\),是不是很清新。

然后我们就可以愉快地将一个大问题分成两个子问题,再继续愉快地DFS,看起来没毛病对不对。

然而问题是在DFS的过程中我们并不能知道前面的断点依次是多少,深度也就无从得知,自然回溯的时候会出问题。如果记录一下的话就成了纯粹的爆搜。

大概是这样解释的,具体细节咱也解释不太清 (毕竟考场上我连爆搜都没想到)

所以该怎么办呢,考虑一下:区间、断点,一定会有神犇 (因为为同机房就有一个) 能联想到区间DP的四边形不等式优化...于是这题就可以用区间DP来做了。

定义一下 \(f[l][r]\)\(l\)\(r\) 的区间的最小答案。
\(g[l][r]\) 为 区间 \([l,r]\) 的最优断点,不理解请自行学习四边形不等式优化。(只是这个人太菜不会讲而已)

于是结合上面关于DFS的思考,有了转移方程: \(f[l][r] = min(f[l][r], f[l][k - 1] + f[k + 1][r] + sum[r] - sum[l-1])\)

其中 \(k\) 为我们枚举的断点,根据四边形不等式可以判断最优点一定在 \(g[l][r-1]\)\(g[l+1][r]\) 之间。

然后我们就可以愉快地DP了。

然后当你愉快地打出区间DP的板子,发现T飞了。

再然后,这是为什么呢。因为我们一般的区间DP都是第一维枚举长度,第二维枚举左端点 \(l\),这样一般是没问题的,复杂度也是稳稳的 \(n^2\),但他就是T了。

这涉及到二维数组的存储和随机访问的效率问题,感性理解一下就是二维数组在内存里是一行一行来存储的,如果我们先枚举长度再枚举端点,那么每相邻两次的 \(l\)\(l\)\(r\)\(r\) 是不连续的,所以在枚举时就会在一行行内存里跳来跳去,用屁股想也知道这样会慢。

但是一般只要算法的瓶颈复杂度足够了,这样小常数不算什么。然而,有一种生物叫做毒瘤出题人...

所以就有了下面代码中的写法,\(l\) 倒序枚举, \(r\) 正序枚举,既可以保证正确性 (正确性应该不需要我证吧...),又可以保证在访问内存时 \(l\) 固定,\(r\) 也是连续的,这样就只会在一行当中访问,自然会快一些。

当然别的区间DP也是可以这么写的。

#include <bits/stdc++.h>
using namespace std;
const int maxn = 5005;

char buf[1 << 20], *p1 = buf, *p2 = buf;
char getc() {
	if(p1 == p2) {
		p1 = buf, p2 = buf + fread(buf, 1, 1 << 20, stdin);
		if(p1 == p2) return EOF; 
	}
	return *p1++;
}
int read() {
	int s = 0, w = 1;
	char c = getc();
	while(c < '0' || c > '9') {if(c == '-') w = -1; c = getc();}
	while(c >= '0' && c <= '9') s = s * 10 + c - '0', c = getc();
	return s * w;
}

int n;
long long sum[maxn], f[maxn][maxn];
int g[maxn][maxn];

int main() {
	n = read();
	for(int i = 1; i <= n; i++) f[i][i] = read(), sum[i] = sum[i - 1] + f[i][i], g[i][i] = i;
	
	for(int l = n - 1; l >= 1; l--) {
		for(int r = l + 1; r <= n; r++) {
			f[l][r] = LLONG_MAX; //注意是LL,INT不够大
			for(int k = g[l][r - 1]; k <= g[l + 1][r]; k++) {
				if(f[l][r] > f[l][k - 1] + f[k + 1][r] + sum[r] - sum[l - 1]) {
					f[l][r] = f[l][k - 1] + f[k + 1][r] + sum[r] - sum[l - 1];
					g[l][r] = k;
				}
			}
		}
	}
	printf("%lld\n", f[1][n]);
	return 0;
}
posted @ 2020-09-09 10:41  zfio  阅读(184)  评论(0编辑  收藏  举报