蓝桥杯真题(数据结构)
数据结构
最长不下降子序列
设输入的序列为arr。
考虑将长度为K的区间[l, r]
变为相同值,应该变成什么值比较好呢?不是很好考虑。
换一种视角,我们最终的最长不下降子序列应该包含三段,中间那一段是长度为K的区间,两边应该是由DP得来的。用dp1[i]
表示以i结尾的最长不下降子序列长度,用dp2[i]
表示以i开始的最长不下降子序列长度,
因此对于每个dp1[i]
,我们需要找到一个最大的dp2[j]
,并且j - i >= K + 1, arr[j] >= arr[i]
,这样构成的最长不下降子序列长度为dp1[i] + K + dp2[j]
具体做法就是长度为K的滑动窗口,窗口区间为[l,r]
。对于dp1[l-1]
要计算出max(dp2[j] for j in range(r+1,N) if arr[j] >= arr[i])
。
还需要记得考虑一种边界情况,就是窗口滑到边界时,无法构成三段,要额外计算一下。
为了避免超时,需要快速查询最大的满足要求的dp2[j]
。注意到这里的要求是arr[j] >= arr[i]
,所以可以用权值线段树,线段树维护的区间是离散化后的arr,区间[l,r]
的值为max(dp[i]), l <= i <= r
。
import sys; readline = sys.stdin.readline
read = lambda: [int(x) for x in readline().split()]
alloc = lambda *s: len(s) != 1 and [alloc(*s[1:]) for i in range(int(s[0]) + 2)] or [0] * int(s[0] + 2)
sys.setrecursionlimit(int(1e9))
N, K = read()
# 离散化
arr = read()
nums = sorted(list(set(arr)))
num_i = {v:i for i,v in enumerate(nums)}
arr = [num_i[v] for v in arr]
M = max(arr)
tr = alloc(N * 4) # 用来查询[l, r]最大的dp[i],l <= i <= r
def pushup(u): tr[u] = max(tr[u << 1], tr[u << 1 | 1])
def build(u, l, r):
if l == r: tr[u] = 0
else:
mid = l + r >> 1
build(u << 1, l, mid)
build(u << 1 | 1, mid + 1, r)
pushup(u)
def query(u, l, r, ql, qr):
if ql <= l and r <= qr: return tr[u]
mid = l + r >> 1
a = 0
if ql <= mid: a = max(a, query(u << 1, l, mid, ql, qr))
if qr >= mid + 1: a = max(a, query(u << 1 | 1, mid + 1, r, ql, qr))
return a
def update(u, l, r, pos, val):
if l == pos and r == pos:
tr[u] = max(tr[u], val)
return
mid = l + r >> 1
if pos <= mid: update(u << 1, l, mid, pos, val)
if pos >= mid + 1: update(u << 1 | 1, mid + 1, r, pos, val)
pushup(u)
dp1 = alloc(N); dp2 = alloc(N)
build(1, 0, M)
for i, v in enumerate(arr):
dp1[i] = query(1, 0, M, 0, v) + 1
update(1, 0, M, v, dp1[i])
build(1, 0, M)
ans = K
# i - K + 1 min to 0
for i in range(len(arr) - 1, K - 1 -1, -1):
v = arr[i]
if i - K >= 0: ans = max(ans, dp1[i - K] + K + query(1, 0, M, arr[i - K], M))
dp2[i] = query(1, 0, M, v, M) + 1
update(1, 0, M, v, dp2[i])
ans = max(ans, dp1[N - K - 1] + K)
ans = max(ans, dp2[K] + K)
print(ans)
最大和
这题如果用正常思路写出来是这样的:
import sys; readline = sys.stdin.readline
read = lambda: [int(x) for x in readline().split()]
alloc = lambda *s: len(s) != 1 and [alloc(*s[1:]) for i in range(int(s[0]) + 2)] or [0] * int(s[0] + 2)
N, = read()
primes = []
sieved = alloc(N)
min_prime = dict() # min_prime[i] 表示i的最小质因数
min_prime[1] = 1
for i in range(2, N + 1): # 做一遍线性筛法
if not sieved[i]:
primes.append(i)
min_prime[i] = i # 质数的最小质因数是它本身
j = 0
while i * primes[j] <= N:
sieved[i * primes[j]] = True
min_prime[i * primes[j]] = primes[j]
if i % primes[j] == 0: break # 为了保证primes[j]是i * primes[j]的最小质因数
j += 1
arr = [0] + read()
f = alloc(N * 2)
f[1] = arr[1]
for i in range(2, N + 1): f[i] = -float('inf')
for i in range(1, N):
x = min_prime[N - i]
for j in range(i + 1, min(N + 1, i + x + 1)):
f[j] = max(f[j], f[i] + arr[j])
print(f[N])
然而对于Python来说只能过90%的样例,如果是C++应该可以过的,无奈。只能想优化了。
注意到内层的循环是对一个区间来做更新,因此可以想到用线段树。对于区间修改,线段树可以使用懒标记来进行优化。这里的arr[j]
要独立出去,因为区间中每个元素的arr[j]
都不一样,不能优化。因此线段树要维护的是最大的f,也就是说tr[u] = max(f[i] for any i if i can jump to u)
,然后在计算f[i]
值的时候再额外加arr[i]
即可。
import sys; readline = sys.stdin.readline
read = lambda: [int(x) for x in readline().split()]
alloc = lambda *s: len(s) != 1 and [alloc(*s[1:]) for i in range(int(s[0]) + 2)] or [0] * int(s[0] + 2)
N, = read()
primes = []
sieved = alloc(N)
min_prime = dict()
min_prime[1] = 1
for i in range(2, N + 1):
if not sieved[i]:
primes.append(i)
min_prime[i] = i
j = 0
while i * primes[j] <= N:
sieved[i * primes[j]] = True
min_prime[i * primes[j]] = primes[j]
if i % primes[j] == 0: break
j += 1
tr = [-1e9] * (N * 4)
def pushdown(u): # 将区间的修改下推到两个子区间
tr[u << 1] = max(tr[u << 1], tr[u])
tr[u << 1 | 1] = max(tr[u << 1 | 1], tr[u])
def query(u, l, r, pos):
if l == pos and r == pos: return tr[u]
pushdown(u)
mid = l + r >> 1
if pos <= mid: return query(u << 1, l, mid, pos)
else: return query(u << 1 | 1, mid + 1, r, pos)
def update(u, l, r, ml, mr, x):
if ml <= l and r <= mr: tr[u] = max(tr[u], x) # 完全被包含,不再继续向下递归
else:
pushdown(u)
mid = l + r >> 1
if ml <= mid: update(u << 1, l, mid, ml, mr, x)
if mr >= mid + 1: update(u << 1 | 1, mid + 1, r, ml, mr, x)
arr = [0] + read()
f = alloc(N * 2)
f[1] = arr[1]
for i in range(2, N + 1): f[i] = -float('inf')
for i in range(1, N):
x = min_prime[N - i]
if i > 1: f[i] = query(1, 1, N, i) + arr[i]
update(1, 1, N, i + 1, min(N, i + x), f[i])
## update相当于优化掉了如下循环
## for j in range(i + 1, min(N + 1, i + x + 1)):
## f[j] < max(f[j], f[i] + arr[j])
print(query(1, 1, N, N) + arr[N])
扫描游戏
首先所有点按照角度顺时针排序,也就是从棒的初始位置顺时针转一圈的顺序。
假设棒位于point[pos]
处,则接下来需要查找[pos + 1, n]
中长度小于等于棒的长度的最左边的点。这可以使用线段树在实现。如果没找到,则继续在[1, pos - 1]
中查找。如果还没找到,则可以停止循环。
每次找到的新点还需要和上一个点进行比较,看是否可以同时消去,也就是看象限是否一致并且叉积为0。
这题Python只能过60%数据,用C++可以过。
import sys; readline = sys.stdin.readline
read = lambda: [int(x) for x in readline().split()]
alloc = lambda *s: len(s) != 1 and [alloc(*s[1:]) for i in range(int(s[0]) + 2)] or [0] * int(s[0] + 2)
from functools import cmp_to_key
sys.setrecursionlimit(int(1e9))
n, len_ = read()
def quadrant(p):
x, y = p[0], p[1]
if x >= 0 and y > 0: return 1
if x > 0 and y <= 0: return 2
if x <= 0 and y < 0: return 3
return 4
def cross(p1, p2):
x1, y1 = p1[0], p1[1]; x2, y2 = p2[0], p2[1]
return x1 * y2 - y1 * x2
points = [None] * n
for i in range(n):
x, y, z = read()
points[i] = x, y, z, i, x * x + y * y
def cmp(p1, p2):
x1, y1 = p1[0], p1[1]; x2, y2 = p2[0], p2[1]
q1 = quadrant(p1); q2 = quadrant(p2)
if q1 != q2: return q1 - q2
c = cross(p1, p2)
if c != 0: return c
return p1[-1] - p2[-1]
points.sort(key=cmp_to_key(cmp))
points = [None] + points
INF = float('inf')
tr = [INF] * (n * 4) # 到原点的最小值
def pushup(u): tr[u] = min(tr[u << 1], tr[u << 1 | 1])
def build(u, l, r):
if l == r:
tr[u] = points[l][-1]
else:
mid = l + r >> 1
build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r)
pushup(u)
build(1, 1, n)
def search(u, l, r, ql, qr, x): # 查找[ql,qr]中第一个小于等于x的元素
if ql <= l and r <= qr:
if tr[u] > x: return -1
if l == r: return l
mid = l + r >> 1
if tr[u << 1] <= x: return search(u << 1, l, mid, ql, qr, x)
return search(u << 1 | 1, mid + 1, r, ql, qr, x)
mid = l + r >> 1
if qr <= mid: return search(u << 1, l, mid, ql, qr, x)
if ql >= mid + 1: return search(u << 1 | 1, mid + 1, r, ql, qr, x)
pos = search(u << 1, l, mid, ql, qr, x)
if pos == -1: return search(u << 1 | 1, mid + 1, r, ql, qr, x)
return pos
def modify(u, l, r, pos, x):
if l == pos and r == pos: tr[u] = x
else:
mid = l + r >> 1
if pos <= mid: modify(u << 1, l, mid, pos, x)
else: modify(u << 1 | 1, mid + 1, r, pos, x)
pushup(u)
rank = 0; last_rank = 0
pos = 0; flag = False
ans = [-1] * n
while True:
idx = -1
if pos + 1 <= n: idx = search(1, 1, n, pos + 1, n, len_ * len_)
if idx == -1:
if pos >= 2: idx = search(1, 1, n, 1, pos - 1, len_ * len_)
if idx == -1: break
len_ += points[idx][2]
modify(1, 1, n, idx, INF)
if pos and quadrant(points[idx]) == quadrant(points[pos]) and cross(points[idx], points[pos]) == 0:
ans[points[idx][-2]] = last_rank; rank += 1
else:
rank += 1; ans[points[idx][-2]] = rank; last_rank = rank
pos = idx
print(' '.join(map(str, ans)))