P10220 [省选联考 2024] 迷宫守卫 题解
题意简述
一棵完全二叉树,对于每一个节点,Alice 可以花费一定的代价,强制 Bob 先往左子树走,求到达叶子结点的最小字典序。
题目分析
算法:贪心
我们意识到到一点:比较两个排列的字典序大小时,只需要比较他们不相同的第一位即可。也就是说,对于任意一个状态,Alice 肯定会不惜一切代价(当然是在 \(K\) 的范围内)使走到的第一个叶节点最小。所以我们先考虑如何找出第一个节点,剩下的就好设计了。(我们使用线段树的方式遍历整棵树,下面左孩子称为 \(i\times2\) ,右孩子为 \(i\times2+1\))
我们对每一个点开一个 vector,\(v_{i,j}\) 代表从 \(i\) 开始,Alice 通过激活石像,强制 Bob 走到 \(j\) 的最小代价(即 Bob 从 \(i\) 点出发的所有方案中,第一个到达的节点最小是 \(j\) )。
对于每一个节点 \(i\),我们遍历它下面的所有叶子节点,更新答案。对于左子树中的叶子节点 \(j\), 有两种方式:
- 先强制使 Bob 走到左子树,再加上左子树的答案,即:
- 不激活此处神像,但是 Bob 走到右子树时,会不得不走到一个 \(q\) 值更大的点。(此处顺便记录用来威慑 Bob 的那一个点 \(k\) 以便统计答案)即:
对于右子树,只需考虑后者即可。
这样我们就可以简单地算出从以任意一个节点为第一个到达的节点的最小代价,只需要找到这个代价小于等于 \(K\) 的中,\(q\) 值最大的那个即可。这样我们就找到了第一个点。
对于后面的点:我们定义 \(Solve(i,need)\) 为从第 \(i\) 个点出发并且此时要求第一个到达的点是 \(need\) 。调用函数时,遇到叶子节点就直接输出,对于其他节点:
- 首先向 \(need\) 所在子树递归,如 \(need\) 在左子树则调用 \(Solve(i\times2,need)\) 。
- 对于另一个子树。我们发现我们第一遍遍历时只求了代价最小的方案而非字典序最大的方案,所以先撤销掉原来的方案(当时记录的 \(k\) 就派上用场啦)再找一个代价仍在范围内的,字典序最大的点,将 \(need\) 改为这个点再遍历。
- 还有一种特殊情况
考场上没想到,然后大样例挂了。就是对于左子树,我们可能在原来的方案中使用的是点亮神像而非威慑,担当我们选取字典序最大的点时,这个点可能又会起到威慑作用,这样就可以把原来的 \(w_i\) 扔掉了(具体细节详见代码)。
这样这道题就差不多做完了 代码_75pts
实现 & 优化
我们发现这样做是 \(O(n^2\log(n))\) 的,我们只需将vector按 \(q\) 值排序,统计答案的后缀最小值,二分查找即可。时间复杂度为 \(O(n\log^2n)\) (用归并可以再少一个 \(\log\))
记得多测清空。
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+5;
const long long INF=1e18;
int T,n;
long long w[N],K;
int q[N];
struct node{long long val;int er,id;};
vector<node> v[N];
vector<long long> minn[N];
vector<int> mnx[N];
vector<int> rk[N];
vector<int> num;
bool cmpq(node a,node b){return q[a.id]<q[b.id];}
bool cmpid(node a,node b){return a.id<b.id;}
void DFS(int i,int l,int r){
if(l==r){
v[i].push_back((node){0,0,l});
minn[i].push_back(0);minn[i].push_back(INF);
mnx[i].push_back(l);mnx[i].push_back(0);
rk[i].push_back(0);
return ;
}
int mid=(l+r)>>1;
DFS(i*2,l,mid);
DFS(i*2+1,mid+1,r);
for(int j=l;j<=mid;j++){
long long mn;int er;
int pos=lower_bound(v[i*2+1].begin(),v[i*2+1].end(),(node){0,0,j},cmpq)-v[i*2+1].begin();
mn=minn[i*2+1][pos],er=mnx[i*2+1][pos];
v[i].push_back((node){0,0,j});
v[i][j-l].val=min(w[i],mn)+v[i*2][rk[i*2][j-l]].val;
if(mn<w[i])v[i][j-l].er=er;
if(v[i][j-l].val>INF)v[i][j-l].val=INF;
}
for(int j=mid+1;j<=r;j++){
long long mn;int er;
int pos=lower_bound(v[i*2].begin(),v[i*2].end(),(node){0,0,j},cmpq)-v[i*2].begin();
mn=minn[i*2][pos],er=mnx[i*2][pos];
v[i].push_back((node){0,0,j});
v[i][j-l].val=mn+v[i*2+1][rk[i*2+1][j-mid-1]].val;
v[i][j-l].er=er;
if(v[i][j-l].val>INF)v[i][j-l].val=INF;
}
sort(v[i*2].begin(),v[i*2].end(),cmpid);
sort(v[i*2+1].begin(),v[i*2+1].end(),cmpid);
sort(v[i].begin(),v[i].end(),cmpq);
for(int j=l;j<=r;j++)
minn[i].push_back(0),mnx[i].push_back(0),rk[i].push_back(0);
for(int j=r;j>=l;j--){
rk[i][v[i][j-l].id-l]=j-l;
if(j==r){
minn[i][j-l]=v[i][j-l].val;
mnx[i][j-l]=v[i][j-l].id;
continue;
}
minn[i][j-l]=minn[i][j-l+1];
mnx[i][j-l]=mnx[i][j-l+1];
if(v[i][j-l].val<minn[i][j-l+1]){
minn[i][j-l]=v[i][j-l].val;
mnx[i][j-l]=v[i][j-l].id;
}
}
minn[i].push_back(INF);
mnx[i].push_back(0);
}
int Find(int i,int l,int r){
int ans=0;
for(int j=l;j<=r;j++){
if(v[i][j-l].val<=K&&q[ans]<q[j])ans=j;
}
K-=v[i][ans-l].val;
return ans;
}
void Solve(int i,int l,int r,int need){
if(l==r)return num.push_back(q[l]),void();
int mid=(l+r)>>1;
if(need<=mid){
Solve(i*2,l,mid,need);
int newrt;
if(v[i][need-l].er)K+=v[i*2+1][v[i][need-l].er-mid-1].val,newrt=Find(i*2+1,mid+1,r);
else{
K+=w[i];newrt=Find(i*2+1,mid+1,r);
if(q[newrt]<q[need]){
K+=v[i*2+1][newrt-mid-1].val-w[i];
newrt=Find(i*2+1,mid+1,r);
}
}
Solve(i*2+1,mid+1,r,newrt);
}
else{
Solve(i*2+1,mid+1,r,need);
if(v[i][need-l].er)K+=v[i*2][v[i][need-l].er-l].val;
int newrt=Find(i*2,l,mid);
Solve(i*2,l,mid,newrt);
}
}
int main(){
scanf("%d",&T);
while(T--){
scanf("%d%lld",&n,&K);
for(int i=1;i<=(1<<n)-1;i++)scanf("%lld",&w[i]);
for(int i=1;i<=(1<<n);i++)scanf("%d",&q[i]);
DFS(1,1,(1<<n));
sort(v[1].begin(),v[1].end(),cmpid);
int mn=Find(1,1,(1<<n));
Solve(1,1,(1<<n),mn);
for(int i=0;i<(1<<n);i++){
printf("%d",num[i]);
if(i!=(1<<n)-1)printf(" ");
}
printf("\n");
for(int i=1;i<=(1<<n+1)-1;i++)v[i].clear();
for(int i=1;i<=(1<<n+1)-1;i++)minn[i].clear();
for(int i=1;i<=(1<<n+1)-1;i++)mnx[i].clear();
for(int i=1;i<=(1<<n+1)-1;i++)rk[i].clear();
num.clear();
}
return 0;
}