【暑假集训模拟DAY8】树上问题

前言

树上问题以为还算了解

但现在看来确实缺的也很多,包括换根DP,树的重心,环套树等等,都不是特别熟悉

期望:0+30+50+0=80pts

实际:0+10+50+0=60pts

中间打疫苗中断也有点影响(回来就有点不想写题了)

T2标准的暴力分是20,然而我又写了一些特判的部分分(比如判链,判两点相同),反而把暴力搞炸了...

题解

T1 reform

理解倒是不难,却感觉无从下手

对于每一个点,如果原来不是重心,那么考虑切掉一个大小>n/2的部分的一部分(好像有点绕)

规定1号点为根,那么对于节点u切掉的部分可能在u子树里,也可能在u子树外,需要分别求出

如果在u子树里,比较好求出能切掉的最大的但是大小不超过n/2的真子树,同时为了下一步求出子树外的最大部分,也要记录子树内能切掉的次大部分

对于子树外的最大部分,有3种转移方式,具体见代码

还有这题用memset清空会TLE,测试了一下觉得是因为memset每次会把N大小的数组全部清空,如果T过大就会TLE;如果循环到n清空,由于n的和不会太大,所以不会TLE

总之测试数据多的时候慎用memset

代码:

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int INF = 0x3f3f3f3f,N = 4e5+10;
int head[N<<1],ecnt=-1;
int ins[N][2],outs[N];
int siz[N],n,T;
//inside,outside
//ins[u][0]表示最大子树,ins[u][1]表示次大子树 
void init()
{
	memset(head,-1,sizeof(head));
	ecnt=-1;
}
struct edge
{
	int nxt,to;
}a[N<<1];
void add(int x,int y)
{
	a[++ecnt].nxt=head[x];
	a[ecnt].to=y;
	head[x]=ecnt;
}
void dfs1(int u,int fa)//这里是回溯的时候转移ins 
{
	siz[u]=1;
	for(int i=head[u];~i;i=a[i].nxt)
	{
		int v=a[i].to;
		if(v==fa) continue;
		dfs1(v,u);
		siz[u]+=siz[v];
		int tmp;
		if(siz[v]<=n/2) tmp=siz[v];
		else tmp=ins[v][0];
		if(ins[u][0]<=tmp) ins[u][1]=ins[u][0],ins[u][0]=tmp;
		else if(ins[u][1]<=tmp) ins[u][1]=tmp;
	}//更新需要维护的2个信息:子树内最大值、次大值,为下面求子树外最大值做准备 
}
void dfs2(int u,int fa)//这里是从上往下的时候转移outs 
{
	for(int i=head[u];~i;i=a[i].nxt)
	{
		int v=a[i].to;
		if(v==fa) continue;
		//outs[v]有三个转移方式:子树内最大值,子树内次大值,子树外最大值(均为u的) 
		if(n-siz[v]<=n/2) outs[v]=n-siz[v];
		else outs[v]=outs[u];
		if(ins[v][0]==ins[u][0]||siz[v]==ins[u][0]) outs[v]=max(outs[v],ins[u][1]);
		else outs[v]=max(outs[v],ins[u][0]); 
		dfs2(v,u);
	}
}

int main()
{
	scanf("%d",&T);
	while(T--)
	{
		scanf("%d",&n);
		//init();
		ecnt=-1;
		for(int i=1;i<=n;i++) outs[i]=ins[i][1]=ins[i][0]=siz[i]=0,head[i]=head[i+n]=-1;
		for(int i=1;i<n;i++)
		{
			int u,v;
			scanf("%d%d",&u,&v);
			add(u,v),add(v,u);
		}
		dfs1(1,-1);
		dfs2(1,-1);
		for(int i=1;i<=n;i++)
		{
			bool flag=0;
			//printf("i=%d\n",i);
			if(n-siz[i]<=n/2)
			{
				for(int j=head[i];~j;j=a[j].nxt)
				{
					int v=a[j].to;
					if(siz[i]>=siz[v]&&siz[v]-ins[v][0]>n/2)//写下标的时候都要仔细思考 
						flag=1,printf("0 ");
				}
				if(!flag) printf("1 ");
			}
			else 
			{
				if(n-siz[i]-outs[i]>n/2) printf("0 ");
				else printf("1 ");
			}
		}
		printf("\n");
	}
	return 0;
}
/*
2
3
1 2
2 3
5
1 2
1 3
1 4
1 5
*/

T2 build

题意:给出一棵树上两个点,求出sigma(树上所有点到这两点中的较小距离)

暴力过不了:因为多次询问,每次询问都要重新对每个点找到最短距离,复杂度O(nmlogn)(或者大概O(nm),一时间居然忘了我考场怎么写的暴力了...)

正解:对于每个点DP求出子树内和子树外所有点到根的深度之和,倍增找到两个给定点的中点,就可以把树划分成两部分,两部分答案分别是到两个点的距离之和

注意:这里也是用到两次dfs,分别在从下往上回溯的时候和从上往下的时候转移DP

代码:

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int INF = 0x3f3f3f3f;
const int N=1e5+100;
int n,m;
int u,v,x,y;
struct node{
	int to,nxt;
}a[N<<1];
int head[N],ecnt=-1;
void add(int x,int y){
	a[++ecnt]=(node){y,head[x]};
	head[x]=ecnt;
}
int dep[N];
ll dis[N],up[N],siz[N],f[N][22];
void dfs1(int u,int fa)
{
	siz[u]=1;
	for(int i=1;i<=20;i++) f[u][i]=f[f[u][i-1]][i-1];
	for(int i=head[u];~i;i=a[i].nxt)
	{
		int v=a[i].to;
		
		if(v==fa) continue;
		dep[v]=dep[u]+1;
		f[v][0]=u;
		dfs1(v,u);
		siz[u]+=siz[v];
		dis[u]+=dis[v]+siz[v];
	}
}
void dfs2(int u,int fa)
{
	for(int i=head[u];~i;i=a[i].nxt)
	{
		int v=a[i].to;
		if(v==fa) continue;
		up[v]=up[u]+dis[u]-dis[v]-siz[v]+n-siz[v]; 
		dfs2(v,u);
	}
}
int lca(int x,int y)
{
	if(dep[x]<dep[y]) swap(x,y);
	for(int i=20;i>=0;i--)
	{
		if(f[x][i]&&dep[f[x][i]]>=dep[y]) x=f[x][i];
	}
	if(x==y) return x;
	for(int i=20;i>=0;i--)
	{
		if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	}
	return f[x][0];
}
inline int Dis(int x,int y)
{
	return dep[x]+dep[y]-2*dep[lca(x,y)];
}
void getmid(){
	int anc=lca(u,v);
	int len=Dis(u,v);
	int s=len/2;if(len&1==0) s--;
	if(dep[u]<dep[v]) swap(u,v);
	x=u;
	for(int k=20;k>=0;k--)
	{
		if(s<1<<k) continue; 
		x=f[x][k],s-=(1<<k);
	}	
	y=f[x][0];
}
int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d",&n);
	for(int i=1;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y);add(y,x);
	}
	dfs1(1,0);
	dfs2(1,0);
	scanf("%d",&m);
	for(int i=1;i<=m;i++){
		scanf("%d%d",&u,&v);
		getmid();
		ll ans=0;
		if(dep[x]<dep[y]) swap(x,y);
		ans+=dis[u]+up[u]-up[x]-(n-siz[x])*Dis(x,u);
		ans+=dis[v]+up[v]-dis[x]-(siz[x])*Dis(x,v);
		//printf("u=%lld v=%lld\n",dis[u]+up[u]-up[x]-(n-siz[x])*Dis(x,u),dis[v]+up[v]-dis[x]-(siz[x])*Dis(x,v));
		//printf("dis[u]=%d,dis[v]=%d,up[u]=%d,up[v]=%d\n",dis[u],dis[v],up[u],up[v]);
		printf("%lld\n",ans);
	}
	return 0;
}
posted @ 2021-08-18 00:12  conprour  阅读(24)  评论(0编辑  收藏  举报