三种方法求解最大子区间和:DP、前缀和、分治
题目
洛谷:P1115 最大子段和
LeetCode:最大子序和
给出一个长度为 \(n\) 的序列 \(a\),选出其中连续且非空的一段使得这段和最大。
挺经典的一道题目,下面分别介绍 \(O(n)\) 的 DP 做法、前缀和做法,以及 \(O(n\log n)\) 的分治做法。
DP 做法
用 \(d_i\) 表示结尾为位置 \(i\) 的最大区间和,则有
问题的答案即为 \(\max\{d_i \mid i\in[1,n]\}\)。
编写代码时不需要开 \(d\) 数组,用变量 last_d
记录 \(d_{i-1}\),变量 ans
记录 \(\max\{d_i\}\),并在扫描时动态更新即可。
时间复杂度 \(O(n)\),空间复杂度 \(O(1)\)。
核心代码如下:
maxn = int(2e5 + 5)
arr = [0 for _ in range(maxn)] # 从下标 1 开始存
# 输入过程略……
ans = None
last_d = 0
for i in range(1, n + 1):
temp_ans = max(last_d, 0) + arr[i]
if ans is None or temp_ans > ans:
ans = temp_ans
last_d = temp_ans
print(ans)
前缀和做法
将数列前 \(n\) 项的和记为 \(sum_n\):
可以用前缀和快速求区间和:
用 \(d_i\) 表示结尾为位置 \(i\) 的最大区间和,则有
问题的答案即为 \(\max\{d_i \mid i \in [1,n]\}\)。
编写代码时只需要开前缀和数组,无需开 \(d\) 数组,用变量 cur_min_pre_sum
记录 \(\min\{sum_j\}\),变量 ans
记录 \(\max\{d_i\}\),并动态维护即可。
时间复杂度 \(O(n)\),空间复杂度 \(O(n)\)。
核心代码如下:
maxn = int(2e5 + 5)
arr = [0 for _ in range(maxn)] # 原数组,从下标 1 开始存
pre_sum = [0 for _ in range(maxn)] # 前缀和数组
# 输入过程略……
# 预处理前缀和
for i in range(1, n + 1):
pre_sum[i] = pre_sum[i - 1] + arr[i]
cur_min_pre_sum = 0
ans = None
for i in range(1, n + 1):
temp_ans = pre_sum[i] - cur_min_pre_sum
if ans is None or temp_ans > ans:
ans = temp_ans
cur_min_pre_sum = min(cur_min_pre_sum, pre_sum[i])
print(ans)
分治做法
若有一区间 \([start,stop)\),区间中点为 \(mid\),其最大子段和对应的子区间为 \([i,j)\),则 \([i,j)\) 只有以下三种情况:
- \([i,j)\) 完全在左子区间 \([start,mid)\) 内;
- \([i,j)\) 完全在右子区间 \([mid,stop)\) 内;
- \([i,j)\) 横跨中点 \(mid\)。
求出这三种情况下的值,取最大的即可。
前两种情况可通过递归求解,求解第三种情况需要一点技巧,方法是从中点出发分别向左右两边延伸。
时间复杂度 \(O(n\log n)\)。
核心代码如下:
maxn = int(2e5 + 5)
arr = [0 for _ in range(maxn)] # 从下标 1 开始存
# 从位置 mid - 1 开始向左延伸的最大区间和
# 注:左子区间 [start, mid)
def mid_lmax(start: int, mid: int) -> int:
ans = None
cur_sum = 0
for i in range(mid - 1, start - 1, -1):
cur_sum += arr[i]
if ans is None or cur_sum > ans:
ans = cur_sum
return ans
# 从位置 mid 开始向右延伸的最大区间和
# 注:右子区间 [mid, stop)
def mid_rmax(mid: int, stop: int) -> int:
ans = None
cur_sum = 0
for i in range(mid, stop):
cur_sum += arr[i]
if ans is None or cur_sum > ans:
ans = cur_sum
return ans
# [start, stop) 的最大子区间和
def solve(start: int, stop: int) -> int:
if stop - start == 1:
return arr[start]
mid = (start + stop) // 2
only_lmax = solve(start, mid) # 完全在左子区间内
only_rmax = solve(mid, stop) # 完全在右子区间内
span_max = mid_lmax(start, mid) + mid_rmax(mid, stop) # 横跨中点
return max(only_lmax, only_rmax, span_max)