Loj#2324-「清华集训 2017」小 Y 和二叉树

正题

题目链接:https://loj.ac/p/2324


题目大意

给出\(n\)个点的一棵树,每个点的度数不超过\(3\)

你要求它的一个二叉树结构(根任意选择)使得其中序遍历的字典序最小。

\(1\leq n\leq 10^6\)


解题思路

直接找根感觉比较麻烦,我们考虑先确定中序遍历中的第一个点。

显然这个点是最小的一个度数不为\(3\)的节点,我们设为\(x\)

此时\(x\)的左子树肯定没有节点,然后考虑它连接的点安排到右子树或者父节点。

先假设\(x\)的度数是\(2\),因为下一步是遍历\(x\)的右子树,所以我们优先比对两个连接的部分作为子树时字典序最小的第一个数是啥。

因为每个点的度数都不超过\(3\),整张图可能出现的子树(不同的根)数量为\(2n-2\)个(每条边的两个方向),我们可以先预处理出每个子树字典序最小时第一个是啥。

这样我们就能快速比较了。

然后考虑\(x\)的度数是\(1\)的时候,记它连接的节点是\(y\)

  • \(y\)的度数为\(0\),那随便丢哪都一样。
  • \(y\)的度数为\(1\),显然\(y\)子节点时可以控制它和它子树的顺序,丢右子树肯定最优。
  • \(y\)的度数为\(2\),考虑丢右子树的优势是它一定可以控制左右两棵子树的顺序,而丢父节点的优势是可以优先把\(y\)遍历掉,我们比较\(y\)\(y\)作为儿子时子树中最小字典序的第一个数,这样就可以确定丢哪了。

至于被丢到子树里面的,我们上面的预处理可以确定子树里面的顺序。

时间复杂度:\(O(n)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e6+10;
int n,tot,rt,k[N],f[N][3],g[N][3],t[N][2];
bool v[N];
void dfs(int x,int fa,int p){
	if(f[fa][p]<=n+1)return;
	if(k[x]==3)f[fa][p]=n+1;
	else f[fa][p]=x;
	for(int i=0;i<k[x];i++){
		int y=g[x][i];
		if(y==fa)continue;dfs(y,x,i);
		f[fa][p]=min(f[fa][p],f[x][i]);
	}
	return;
}
void del(int x,int y){
	for(int i=0;i<k[x];i++)
		if(g[x][i]==y){
			swap(g[x][i],g[x][k[x]-1]);
			swap(f[x][i],f[x][k[x]-1]);
			k[x]--;
		}
	return;
}
void solve(int x){
	rt=x;v[x]=1;
	if(!k[x])return;
	if(k[x]==1){
		int y=g[x][0];
		if(k[y]==3&&y<f[x][0])
			del(y,x),t[y][0]=x,solve(y);
		else del(y,x),t[x][1]=y;
	}
	else{
		if(f[x][0]>f[x][1])swap(g[x][0],g[x][1]);
		int y=g[x][1];t[x][1]=g[x][0];
		del(y,x);del(g[x][0],x);
		t[y][0]=x;solve(g[x][1]);
	}
	return;
}
void print(int x){
	if(!x)return;
	if(v[x]){
		print(t[x][0]);
		printf("%d ",x);
		print(t[x][1]);
	}
	else if(k[x]==2){
		if(f[x][0]>f[x][1])swap(g[x][0],g[x][1]);
		del(g[x][0],x);print(g[x][0]);
		printf("%d ",x);
		del(g[x][1],x);print(g[x][1]);
	}
	else if(k[x]==1){
		if(x<f[x][0])printf("%d ",x);
		del(g[x][0],x);print(g[x][0]);
		if(x>f[x][0])printf("%d ",x);
	}
	else printf("%d ",x);
}
int main()
{
	memset(f,0x3f,sizeof(f));
	scanf("%d",&n);tot=1;
	for(int i=1;i<=n;i++){
		scanf("%d",&k[i]);
		for(int j=0,x;j<k[i];j++)scanf("%d",&g[i][j]);
	}
	for(int i=1;i<=n;i++)
		for(int j=0;j<k[i];j++)dfs(g[i][j],i,j);
	int x=0;
	for(int i=1;i<=n;i++)
		if(k[i]<=2){x=i;break;}
	solve(x);
	print(rt);
	return 0;
}
posted @ 2022-07-15 18:49  QuantAsk  阅读(44)  评论(0编辑  收藏  举报