The 2023 ICPC Asia Hong Kong Regional Programming Contest (The 1st Universal Cup, Stage 2:Hong Kong) L. Permutation Compression 线段树二分
没学过线段树二分,也不知道这个东西算不算线段树二分,感觉蛮像的。
大致题意:
T组测试数据,每组第一行给出n,m,k三个数字。第二行输入大小为n的排列。第三行输入大小为m的一行数字,第四行输入k个数字。要求每次从n里面选择一段连续的区间,删去最大值,从而把第二行变成第三行。第四行是魔法,对于第四行的每一个数字x,表示可以删掉大小为x的区间的最大值,每个魔法只能用一次,最终询问是否能把第一个排列变成和m个数字一模一样。
解题思路:
1. 如果n和m里面出现的数字位置先后顺序不一样无解。
2. 如果删除魔法的数量比n和m的差值小,无解。
3. 我们观察发现如果我们如果选择一个最大的数字进行删除,也就是我们可以选择整个区间,那么所有小于等于这个区间的魔法都可以把他删除掉,也就是说区间选择的越大,魔法的适配性越高,所以考虑从不在m里面的数字从大到小贪心出最大的合法区间。我们使用线段树进行查找,最朴素的想法就是基于当前的值二分在线段树上查找这个位置左右的区间最大值,但是这么做的时间复杂度是n * logn * logn * logn(因为左右两边都要二分是两个二分)的,很明显过不了这个题。
那我们换个思路,能不能在线段树上直接找到左边的第一个不合法位置,和右边的第一个不合法位置。对于查找左边第一个不合法我们就可以这么写。
我们在线段树上维护区间最大值。为了保证区间信息一定是在左边所以出口先设置x在区间左边直接递归左区间,如果不满足上一个条件,那么我们可以判断右儿子的最大值是不是一定没有比当前的值v大,如果成立继续递归左儿子,如果不成立我们需要先递归右儿子的合法区间里面的值是不是都小于等于v(因为右儿子不一定全都是合法区间),如果成立递归左儿子,不成立直接返回先前递归的右儿子信息即可。对于查找右边第一个不合法区间同理。由于我们需要一个第一个不合法位置,递归合法区间的时候还需要一个最大值,所以用二元组类型的返回值。
为了简化代码,我们在0和n+1地方插入两个0xf3f3f3f3f大小的哨兵,这样的话每次只需要l++, r --,那么[l, r]就是一个合法的区间。每次删完当前值,需要把这个点变成0,方便我们统计区间不为0的数字的数量。由于每个区间有些数字可能被删除了,被删除的是不能算进来的,所以我们统计一个值不为0的数字的数量,也就是线段树上在维护一个sum总和。每次query一下[l, r]区间不为0的数量就是能使用的最大魔法的值,小于等于这个数量的魔法都是可以用的。
我们先把所有的魔法丢进一个multiset,然后把所有求出来的最大魔法值压进int类型的vector ans进行排序,我们开始从multiset和ans最开始的地方开始走,一旦出现mutiset里面的值大于ans里面的值我们就可以输出NO了,如果ans走完了都没有这种情况,我们输出YES即可。
VP的时候花了3个小时写了出来其实还是蛮有成就感。
#include <iostream>
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <stack>
#include <queue>
#include <numeric>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <vector>
#include <unordered_set>
#include <cmath>
#include <map>
#include <unordered_map>
#include <set>
#include <deque>
#include <tuple>
#include <array>
#define all(a) a.begin(), a.end()
#define cnt0(x) __builtin_ctz(x)
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define cntone(x) __builtin_popcount(x)
#define db double
#define fs first
#define se second
#define AC main(void)
#define ls u << 1
#define rs u << 1 | 1
typedef std::pair<int, int> PII;
const int N = 2e5 + 10;
const int INF = 0x3f3f3f3f;
inline int max(int a, int b){return a > b ? a : b;}
inline int min(int a, int b){return a > b ? b : a;}
int n, m, k;
int a[N], b[N], pos[N], c[N];
bool vis[N];
struct node {
int mx;
int sum;
}tr[N << 2];
inline void pushup(int u) {
tr[u].mx = max(tr[ls].mx, tr[rs].mx);
tr[u].sum = tr[ls].sum + tr[rs].sum;
}
inline void build(int u, int l, int r) {
if(l == r){
tr[u].mx = a[l];
tr[u].sum = 1;
return ;
}
int mid = l + r >> 1;
build(ls, l, mid);
build(rs, mid + 1, r);
pushup(u);
}
inline void modify(int u, int L, int R, int x, int v) {
if(L == R){
tr[u].mx = v;
tr[u].sum = 0;
return ;
}
int mid = L + R >> 1;
if(x <= mid) modify(ls, L, mid, x, v);
else modify(rs, mid + 1, R, x, v);
pushup(u);
}
inline PII queryl(int u, int L, int R, int x, int v) {
if(L == R) return {L, tr[u].mx};
int mid = L + R >> 1;
if(x <= mid) return queryl(ls, L, mid, x, v);
if(tr[rs].mx <= v) return queryl(ls, L, mid, x, v);
PII right = queryl(rs, mid + 1, R, x, v);
if(right.second <= v) return queryl(ls, L, mid, x, v);
return right;
}
inline PII queryr(int u, int L, int R, int x, int v) {
if(L == R) return {L, tr[u].mx};
int mid = L + R >> 1;
if(x > mid) return queryr(rs, mid + 1, R, x, v);
if(tr[ls].mx <= v) return queryr(rs, mid + 1, R, x, v);
PII left = queryr(ls, L, mid, x, v);
if(left.second <= v) return queryr(rs, mid + 1, R, x, v);
return left;
}
inline int querysum(int u, int L, int R, int l, int r) {
if(L >= l && R <= r) return tr[u].sum;
int sum = 0;
int mid = L + R >> 1;
if(l <= mid) sum += querysum(ls, L, mid, l, r);
if(r > mid) sum += querysum(rs, mid + 1, R, l, r);
return sum;
}
inline void solve() {
std::cin >> n >> m >> k;
a[0] = INF;
a[n + 1] = INF;
for(int i = 1; i <= n; i ++) {
std::cin >> a[i];
pos[a[i]] = i;
vis[i] = false;
}
int last = -1;
bool ok = true;
for(int i = 1; i <= m; i ++) {
int x;
std::cin >> x;
if(pos[x] < last) ok = false;
last = pos[x];
vis[x] = true;
}
std::multiset<int> st;
for(int i = 1; i <= k; i ++){
int x;
std::cin >> x;
st.insert(x);
}
if(n - m > k || !ok){
std::cout << "NO" << '\n';
return ;
}
build(1, 0, n + 1);
std::vector<int> ans;
for(int i = n; i; i --) {
if(vis[i]) continue;
int target = pos[i];//找到位置
int x1 = pos[i], x2 = pos[i];
int l = queryl(1, 0, n + 1, x1, i).first;
int r = queryr(1, 0, n + 1, x2, i).first;
r --; l ++;
int sum = querysum(1, 0, n + 1, l, r);
modify(1, 0, n + 1, target, 0);
ans.push_back(sum);
}
std::sort(ans.begin(), ans.end());
ok = true;
int i = 0, sz = ans.size();
for(std::multiset<int>::iterator it = st.begin(); i < sz; i ++, it ++) {
if(*it > ans[i]) {
ok = false;
break;
}
}
if(ok) std::cout << "YES" << '\n';
else std::cout << "NO" << '\n';
}
int main(void){
std::ios::sync_with_stdio(false);
std::cin.tie(0);
std::cout.tie(0);
int _ = 1;
std::cin >> _;
while(_ --) solve();
return 0;
}