AHOI2022 回忆 题解

洛谷:[AHOI2022] 回忆


题解:事先声明,该做法没有经过严格的证明,我也不清楚它到底对不对,但是它可以通过省选原题的所有数据,欢迎提供证明或者提供 Hack。

首先,显然对于每一个点 \(u\) 只有一条 \((s_i,t_i)\) 使得 \(t_i=u\),若真的有多个,取 \(s_i\) 最浅的。

我们考虑贪心,在以 \(u\) 为根的子树中,存在若干条已经完成的路径和若干条向上的路径等待匹配,等待匹配的路径分为 \(s_i\)\(u\) 之上(即目前还不能匹配),和 \(s_i\)\(u\) 或者在 \(u\) 之下(即目前已经可以匹配),我们使用数据结构维护目前还不能匹配的深度的可重集合,显然,我们希望找到一种方式使得它目前可以匹配的路径数量既可能多,且对于不能匹配的路径 \(s_i\) 的深度都尽可能深。

接下来我们考虑合并自己的孩子集合并将其匹配。

如果当前孩子可以进行匹配的向上路径恰好能够完全匹配(或者只剩一条),那么我们直接进行匹配即可。

否则定然有一个子树的未匹配路径数比其它子树加起来都多,那么我们考虑拆掉其它子树内部已经匹配的路径,再进行匹配。

最后我们要加入一条 \(t_i=u\) 的路径,如果我们目前还不可以匹配的路径集合不为空,那么将它和集合中最浅的合并(根据不能匹配的路径 \(s_i\) 的深度都尽可能深的原则),否则,如果当前子树内存在目前可以匹配但未匹配的路径,将当前点的路径与其匹配,若以上两种情况都不满足,再判断子树内是否有已经匹配的路径,将其拆开,其中一条与当前路径匹配,另一条作为目前可以匹配但未匹配的路径。

由于我们的数据结构需要支持查询最小值,查询最大值,删除最小值,删除最大值,以及合并,我在考场上偷懒写了 set + 启发式合并,得到了 \(O(n\log^2 n)\) 的时间复杂度,然而如果用四个可并堆模拟或者用线段树合并就可以做到 \(O(n\log n)\) 的时间复杂度。

时间复杂度:\(O(n\log^2 n)\)\(O(n\log n)\)

代码:

#include <set>
#include <vector>
#include <cstdio>
#include <algorithm>
const int Maxn=200000;
int n,m;
int head[Maxn+5],arrive[Maxn<<1|5],nxt[Maxn<<1|5],tot;
void add_edge(int from,int to){
	arrive[++tot]=to,nxt[tot]=head[from],head[from]=tot;
}
int s_min[Maxn+5];
int dep[Maxn+5],fa[Maxn+5];
void init_dfs(int u){
	dep[u]=dep[fa[u]]+1;
	for(int i=head[u];i;i=nxt[i]){
		int v=arrive[i];
		if(v==fa[u]){
			continue;
		}
		fa[v]=u;
		init_dfs(v);
	}
}
std::multiset<int> st[Maxn+5];
int match[Maxn+5],out[Maxn+5],putt[Maxn+5];
int num;
void work_dfs(int u){
	match[u]=out[u]=putt[u]=0;
	st[u].clear();
	std::vector<std::pair<int,int> > o_lis;
	for(int i=head[u];i;i=nxt[i]){
		int v=arrive[i];
		if(v==fa[u]){
			continue;
		}
		work_dfs(v);
		match[u]+=match[v];
		while(!st[v].empty()&&*(--st[v].end())>=dep[u]){
			out[v]++,st[v].erase(--st[v].end());
		}
		o_lis.push_back(std::make_pair(out[v],match[v]));
		if(st[u].size()<st[v].size()){
			std::swap(st[u],st[v]);
		}
		for(auto it:st[v]){
			st[u].insert(it);
		}
		st[v].clear();
	}
	if(!o_lis.empty()){
		std::sort(o_lis.begin(),o_lis.end());
		int sum=0;
		for(int i=0;i<(int)o_lis.size()-1;i++){
			sum+=o_lis[i].first+o_lis[i].second*2;
		}
		if(sum>=o_lis.back().first){
			sum=0;
			for(int i=0;i<(int)o_lis.size()-1;i++){
				sum+=o_lis[i].first;
			}
			if(sum>=o_lis.back().first){
				sum+=o_lis.back().first;
				out[u]=sum%2,match[u]+=sum/2;
			}
			else{
				int tmp=o_lis.back().first;
				match[u]+=sum;
				tmp-=sum;
				match[u]+=tmp/2,out[u]=tmp%2;
			}
		}
		else{
			match[u]=sum+o_lis.back().second;
			out[u]=o_lis.back().first-sum;
		}
	}
	if(s_min[u]!=n+1){
		if(st[u].empty()){
			st[u].insert(s_min[u]);
			if(out[u]>0){
				out[u]--,num--;
			}
			else if(match[u]>0){
				match[u]--,out[u]++,num--,putt[u]++;
			}
		}
		else{
			int val=std::min(*st[u].begin(),s_min[u]);
			st[u].erase(st[u].begin());
			st[u].insert(val);
			num--;
		}
	}
}
void solve(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++){
		s_min[i]=n+1,head[i]=0;
	}
	tot=0;
	for(int i=1;i<n;i++){
		int u,v;
		scanf("%d%d",&u,&v);
		add_edge(u,v),add_edge(v,u);
	}
	init_dfs(1);
	num=0;
	for(int i=1;i<=m;i++){
		int s,t;
		scanf("%d%d",&s,&t);
		s_min[t]=std::min(s_min[t],dep[s]);
	}
	for(int i=1;i<=n;i++){
		if(s_min[i]!=n+1){
			num++;
		}
	}
	work_dfs(1);
	printf("%d\n",num-match[1]);
}
int main(){
	int T;
	scanf("%d",&T);
	while(T--){
		solve();
	}
	return 0;
}
posted @ 2022-05-09 23:54  with_hope  阅读(316)  评论(0编辑  收藏  举报