AHOI2022 回忆 题解
题解:事先声明,该做法没有经过严格的证明,我也不清楚它到底对不对,但是它可以通过省选原题的所有数据,欢迎提供证明或者提供 Hack。
首先,显然对于每一个点 \(u\) 只有一条 \((s_i,t_i)\) 使得 \(t_i=u\),若真的有多个,取 \(s_i\) 最浅的。
我们考虑贪心,在以 \(u\) 为根的子树中,存在若干条已经完成的路径和若干条向上的路径等待匹配,等待匹配的路径分为 \(s_i\) 在 \(u\) 之上(即目前还不能匹配),和 \(s_i\) 是 \(u\) 或者在 \(u\) 之下(即目前已经可以匹配),我们使用数据结构维护目前还不能匹配的深度的可重集合,显然,我们希望找到一种方式使得它目前可以匹配的路径数量既可能多,且对于不能匹配的路径 \(s_i\) 的深度都尽可能深。
接下来我们考虑合并自己的孩子集合并将其匹配。
如果当前孩子可以进行匹配的向上路径恰好能够完全匹配(或者只剩一条),那么我们直接进行匹配即可。
否则定然有一个子树的未匹配路径数比其它子树加起来都多,那么我们考虑拆掉其它子树内部已经匹配的路径,再进行匹配。
最后我们要加入一条 \(t_i=u\) 的路径,如果我们目前还不可以匹配的路径集合不为空,那么将它和集合中最浅的合并(根据不能匹配的路径 \(s_i\) 的深度都尽可能深的原则),否则,如果当前子树内存在目前可以匹配但未匹配的路径,将当前点的路径与其匹配,若以上两种情况都不满足,再判断子树内是否有已经匹配的路径,将其拆开,其中一条与当前路径匹配,另一条作为目前可以匹配但未匹配的路径。
由于我们的数据结构需要支持查询最小值,查询最大值,删除最小值,删除最大值,以及合并,我在考场上偷懒写了 set + 启发式合并,得到了 \(O(n\log^2 n)\) 的时间复杂度,然而如果用四个可并堆模拟或者用线段树合并就可以做到 \(O(n\log n)\) 的时间复杂度。
时间复杂度:\(O(n\log^2 n)\) 或 \(O(n\log n)\)。
代码:
#include <set>
#include <vector>
#include <cstdio>
#include <algorithm>
const int Maxn=200000;
int n,m;
int head[Maxn+5],arrive[Maxn<<1|5],nxt[Maxn<<1|5],tot;
void add_edge(int from,int to){
arrive[++tot]=to,nxt[tot]=head[from],head[from]=tot;
}
int s_min[Maxn+5];
int dep[Maxn+5],fa[Maxn+5];
void init_dfs(int u){
dep[u]=dep[fa[u]]+1;
for(int i=head[u];i;i=nxt[i]){
int v=arrive[i];
if(v==fa[u]){
continue;
}
fa[v]=u;
init_dfs(v);
}
}
std::multiset<int> st[Maxn+5];
int match[Maxn+5],out[Maxn+5],putt[Maxn+5];
int num;
void work_dfs(int u){
match[u]=out[u]=putt[u]=0;
st[u].clear();
std::vector<std::pair<int,int> > o_lis;
for(int i=head[u];i;i=nxt[i]){
int v=arrive[i];
if(v==fa[u]){
continue;
}
work_dfs(v);
match[u]+=match[v];
while(!st[v].empty()&&*(--st[v].end())>=dep[u]){
out[v]++,st[v].erase(--st[v].end());
}
o_lis.push_back(std::make_pair(out[v],match[v]));
if(st[u].size()<st[v].size()){
std::swap(st[u],st[v]);
}
for(auto it:st[v]){
st[u].insert(it);
}
st[v].clear();
}
if(!o_lis.empty()){
std::sort(o_lis.begin(),o_lis.end());
int sum=0;
for(int i=0;i<(int)o_lis.size()-1;i++){
sum+=o_lis[i].first+o_lis[i].second*2;
}
if(sum>=o_lis.back().first){
sum=0;
for(int i=0;i<(int)o_lis.size()-1;i++){
sum+=o_lis[i].first;
}
if(sum>=o_lis.back().first){
sum+=o_lis.back().first;
out[u]=sum%2,match[u]+=sum/2;
}
else{
int tmp=o_lis.back().first;
match[u]+=sum;
tmp-=sum;
match[u]+=tmp/2,out[u]=tmp%2;
}
}
else{
match[u]=sum+o_lis.back().second;
out[u]=o_lis.back().first-sum;
}
}
if(s_min[u]!=n+1){
if(st[u].empty()){
st[u].insert(s_min[u]);
if(out[u]>0){
out[u]--,num--;
}
else if(match[u]>0){
match[u]--,out[u]++,num--,putt[u]++;
}
}
else{
int val=std::min(*st[u].begin(),s_min[u]);
st[u].erase(st[u].begin());
st[u].insert(val);
num--;
}
}
}
void solve(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
s_min[i]=n+1,head[i]=0;
}
tot=0;
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
add_edge(u,v),add_edge(v,u);
}
init_dfs(1);
num=0;
for(int i=1;i<=m;i++){
int s,t;
scanf("%d%d",&s,&t);
s_min[t]=std::min(s_min[t],dep[s]);
}
for(int i=1;i<=n;i++){
if(s_min[i]!=n+1){
num++;
}
}
work_dfs(1);
printf("%d\n",num-match[1]);
}
int main(){
int T;
scanf("%d",&T);
while(T--){
solve();
}
return 0;
}
本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。