【CF338E】Optimize! - 线段树
题目描述
Manao came up with a solution that produces the correct answers but is too slow. You are given the pseudocode of his solution, where the function getAnswer calculates the answer to the problem:
getAnswer(a[1..n], b[1..len], h)
answer = 0
for i = 1 to n-len+1
answer = answer + f(a[i..i+len-1], b, h, 1)
return answer
f(s[1..len], b[1..len], h, index)
if index = len+1 then
return 1
for i = 1 to len
if s[index] + b[i] >= h
mem = b[i]
b[i] = 0
res = f(s, b, h, index + 1)
b[i] = mem
if res > 0
return 1
return 0
Your task is to help Manao optimize his algorithm.
题目大意
当两个长度相同的数组存在一种两两匹配方式,使得每一对的和不小于 \(h\),则称之为匹配。
有一个长为 \(m\) 的数组 \(B\),求 \(A\) 有多少长度为 \(m\) 的子序列(连续)中的数和 \(B\) 中的数匹配。
思路
两个数组按照以下方式匹配一定是最优的:
令 \(A\) 的子序列为 \(S\)
\(B\) 最大\(\to S\) 最小,\(B\) 次大\(\to S\) 次小,...,\(B\) 最小\(\to S\) 最大
因为假设 \(S\) 有两数 \(s_1 \le s_2\),\(B\) 有两数 \(b_1 \le b_2\)
不按大对小的匹配方式则有
且此时大对小的匹配方式不满足
若 \(s_1+b_2 < h\ \text{③}\)
则 \(\text{①-③} \implies b_1 > b_2\) 矛盾
同理若 \(s_2+b_1 < h\ \text{④}\)
则 \(\text{①-④} \implies s_1 > s_2\) 矛盾
所以 \(B\) 中的最大至少有 \(m\) 个数在 \(S\) 中能匹配,\(B\) 中的次大至少有 \(m-1\) 个数在 \(S\) 中能匹配,...,\(B\) 中的最小至少有 \(1\) 个数在 \(S\) 中能匹配
将 \(B\) 从小到大排序,用线段树维护 \(B\) 中每个数能被多少个数匹配,将初始可匹配数设为 -1,-2,...,-m
每次将 \(a_i\) 加入时,找到第一个 \(a_i+b_j \ge h\) 的位置 \(j\),并将 \(B\) 中 \([j,m]\) 的可匹配数加 1
若线段树维护的可匹配数最小值大于 0,这个子序列就是合法的
#include <functional>
#include <algorithm>
#include <cstdio>
using namespace std;
const int maxn = 1.5e5 + 10;
int n,m,h,a[maxn],b[maxn],minv[maxn<<2],laz[maxn<<2];
inline void pushdown(int root) {
if (laz[root]) {
laz[root<<1] += laz[root];
laz[root<<1|1] += laz[root];
minv[root<<1] += laz[root];
minv[root<<1|1] += laz[root];
laz[root] = 0;
}
}
inline void pushup(int root) { minv[root] = min(minv[root<<1],minv[root<<1|1]); }
inline void update(int ul,int ur,int x,int l = 1,int r = m,int root = 1) {
if (l > ur || r < ul) return;
if (ul <= l && r <= ur) return laz[root] += x,minv[root] += x,void();
int mid = l+r>>1;
pushdown(root);
update(ul,ur,x,l,mid,root<<1);
update(ul,ur,x,mid+1,r,root<<1|1);
pushup(root);
}
int main() {
scanf("%d%d%d",&n,&m,&h);
for (int i = 1;i <= m;i++) { scanf("%d",&b[i]); update(i,m,-1); }
sort(b+1,b+m+1);
for (int i = 1;i <= n;i++) {
scanf("%d",&a[i]);
a[i] = lower_bound(b+1,b+m+1,h-a[i])-b;
if (i <= m) update(a[i],m,1);
}
int ans = minv[1] >= 0;
for (int i = m+1;i <= n;i++) {
update(a[i],m,1);
update(a[i-m],m,-1);
ans += minv[1] >= 0;
}
printf("%d",ans);
return 0;
}