题解 树上路径

传送门

有一个直接的做法：枚举断开哪一条边，分别求出断开后两个连通块的直径，用这一对直径更新答案（树上两条不相交的路径一定能被某条边分隔到两侧）。
于是用树形DP（先自底向上求子树直径，再自顶向下换根求子树外部分的最长路径）预处理后，每条边都能 \(O(1)\) 得到两侧的直径；最后对答案数组取后缀最大值并求和。
整体复杂度 \(O(n)\)。

Code:
#include <bits/stdc++.h>
using namespace std;
// Conventional "infinity" sentinel for int; not referenced by the code shown here.
#define INF 0x3f3f3f3f
// Capacity of every per-vertex array (the edge array uses 2*N); presumably the
// problem's vertex limit plus slack — TODO confirm against the statement.
#define N 500010
#define ll long long
// Textual shorthands: x.fir / x.sec expand to x.first / x.second, and make(...)
// to make_pair(...). Because these are macros, every identifier spelled fir or
// sec is renamed too (task1's fir[] / sec[] arrays become first[] / second[]).
#define fir first
#define sec second
#define make make_pair
//#define int long long

// Fast input: stdin is pulled in 2 MiB chunks into buf, and getchar() is
// redefined for the rest of the file to serve characters from [p1, p2),
// refilling when that window is empty and yielding EOF once fread returns 0.
// NOTE: *p1++ is a plain char, so a raw 0xFF byte reads back as -1 == EOF.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads the next integer from stdin.
// Skips non-digit characters (each '-' seen while skipping flips the sign),
// then accumulates decimal digits. Returns 0 if the input ends before any digit
// is found instead of looping forever. The character is kept as an int and
// compared against '0'..'9' directly, so EOF and negative char values never
// reach isdigit(), whose behaviour is undefined for them.
inline int read() {
	int ans=0, f=1, c=getchar();
	while (c!=EOF && (c<'0'||c>'9')) {if (c=='-') f=-f; c=getchar();}
	while (c>='0'&&c<='9') {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// Number of vertices.
int n;
// Chained adjacency list ("forward star"): head[u] is the index of the most
// recently added edge leaving u (-1 ends the chain once main() has memset head
// to -1), e[i].next is the previous edge from the same source, and size is the
// number of directed edges stored so far (edge indices start at 1).
// NOTE(review): with `using namespace std;` and -std=c++17, uses of the global
// `size` can clash with std::size ("reference is ambiguous"); renaming it needs
// every use in the file to change together.
int head[N], size;
struct edge{int from, to, next;}e[N<<1];
// Appends the directed edge s -> t. `from` is recorded as well, so code that
// iterates e[] directly can recover both endpoints of an edge.
inline void add(int s, int t) {
	e[++size].from=s; e[size].to=t; e[size].next=head[s]; head[s]=size;
}

// Brute-force reference implementation, compiled out by `#if 0`; main() only
// mentions it in a commented-out call. It appears to have served to cross-check
// task1 on small inputs. NOTE(review): its nested loops are far too slow for
// large n, and it measures path length in edges (dis) while task1 counts
// vertices, so compare outputs with that in mind.
#if 0
namespace force{
	// dep: depth (root = 1); fa[i][u]: 2^i-th ancestor of u;
	// lg[i] = floor(log2(i)) + 1 (filled in solve()); vis: path-marking flags.
	int dep[N], fa[24][N], lg[N];
	bool vis[N];
	// DFS that fills dep[] and the binary-lifting table fa[][]; fa[0][u] is set
	// by the parent before u is visited, and higher levels are built on entry.
	void dfs1(int u, int pa) {
		for (int i=1; i<=20; ++i)
			if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
			else break;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=pa) {
				dep[v]=dep[u]+1; fa[0][v]=u;
				dfs1(v, u);
			}
		}
	}
	// Lowest common ancestor via binary lifting: repeatedly lift the deeper
	// vertex by the largest power of two not exceeding the depth gap, then lift
	// both vertices together while their ancestors differ.
	int lca(int a, int b) {
		if (dep[a]<dep[b]) swap(a, b);
		while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
		if (a==b) return a;
		for (int i=lg[dep[a]]-1; ~i; --i) 
			if (fa[i][a]!=fa[i][b])
				a=fa[i][a], b=fa[i][b];
		return fa[0][a];
	}
	// Marks vis = 1 on every vertex of the path from u to `to`, searching the
	// tree outward from u; returns 1 if `to` was reached through this call.
	// The parameter `fa` shadows the namespace array fa[][] inside this function.
	int paint(int u, int fa, int to) {
		if (u==to) {vis[u]=1; return 1;}
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) {
				int t=paint(v, u, to);
				if (t) {vis[u]=1; return 1;}
			}
		}
		return 0;
	}
	// Same traversal as paint(), but instead of marking, ORs vis[] of every
	// vertex on the u -> to path into `tag` (tag ends up 1 if the path touches
	// any marked vertex).
	int anycol(int u, int fa, int to, bool& tag) {
		if (u==to) {tag|=vis[u]; return 1;}
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) {
				int t=anycol(v, u, to, tag);
				if (t) {tag|=vis[u]; return 1;}
			}
		}
		return 0;
	}
	// Number of edges on the path between a and b.
	int dis(int a, int b) {return dep[a]+dep[b]-2*dep[lca(a, b)];}
	// Returns 1 if some path with x edges and some path with y edges share no
	// vertex: each x-path is marked with paint(), then every y-path is tested
	// with anycol(). Writes debug lines to cout.
	bool check(int x, int y) {
		cout<<"check: "<<x<<' '<<y<<endl;
		for (int i=1; i<=n; ++i) {
			for (int j=1; j<=n; ++j) {
				if (dis(i, j)!=x) continue;
				memset(vis, 0, sizeof(vis));
				paint(i, 0, j);
				for (int k=1; k<=n; ++k) {
					for (int l=1; l<=n; ++l) {
						if (dis(k, l)!=y) continue;
						bool tag=0;
						anycol(k, 0, l, tag);
						if (tag) continue;
						return 1;
					}
				}
			}
		}
		cout<<"return 0"<<endl;
		return 0;
	}
	// Builds lg[] and the lifting table (root 1), counts the pairs (i, j) in
	// [1, n]^2 for which check(i, j) holds, prints the count and exits.
	void solve() {
		for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
		dep[1]=1; dfs1(1, 0);
		int ans=0;
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) if (check(i, j)) ++ans;
		cout<<ans<<endl;
		exit(0);
	}
}
#endif

// O(n) solution. Two vertex-disjoint paths in a tree are always separated by
// some edge, so for every edge we cut the tree there and record the pair
// (diameter of one side, diameter of the other side).
// NOTE(review): dfs1/dfs2 are recursive with depth equal to the tree height; a
// path-shaped input with n close to N may exceed a default stack limit unless
// the judge raises it.
namespace task1{
	// All lengths in this namespace count vertices, not edges (a lone vertex
	// has length 1). The tree is rooted at vertex 1.
	//   f[u]   longest downward path starting at u.
	//   k[u]   diameter of the subtree rooted at u.
	//   g[u]   longest path that starts at u and leaves u's subtree through
	//          u's parent (just u itself, i.e. 1, for the root).
	//   h[u]   longest path outside subtree(u) that passes through a proper
	//          ancestor of u (stays 0 for the root).
	//   dep[u] depth of u, root = 1.
	//   ans[x] best partner length recorded for length x (see solve()).
	//   fir/sec/thr[u]  the three largest f[child] of u, each paired with the
	//          child's id; unused slots keep their zero-initialised (0, 0).
	int f[N], g[N], k[N], h[N], dep[N], ans[N];
	pair<int, int> fir[N], sec[N], thr[N];

	// Bottom-up pass: fills dep, f, k and the top-3 child heights of u.
	void dfs1(int u, int fa) {
		f[u]=k[u]=1;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) {
				dep[v]=dep[u]+1;
				dfs1(v, u);
				f[u]=max(f[u], f[v]+1);
				// Subtree diameter candidates: a single chain hanging from u, or
				// a diameter lying entirely inside a child's subtree ...
				k[u]=max(k[u], f[v]+1);
				k[u]=max(k[u], k[v]);
				// Keep the three largest child heights; on ties the newer child
				// takes the slot and older entries shift down.
				if (f[v]>=fir[u].fir) thr[u]=sec[u], sec[u]=fir[u], fir[u]=make(f[v], v);
				else if (f[v]>=sec[u].fir) thr[u]=sec[u], sec[u]=make(f[v], v);
				else if (f[v]>thr[u].fir) thr[u]=make(f[v], v);
			}
		}
		// ... or a path bending at u through its two highest children.
		if (fir[u].fir&&sec[u].fir) k[u]=max(k[u], fir[u].fir+sec[u].fir+1);
	}

	// Top-down pass. s is the value of g[u], supplied by the parent.
	// For u != root: h[u] = max(h[fa], best path bending at fa that avoids u),
	// where the branches available at fa are the upward branch g[fa] and the
	// downward branches f[c]+1 through children c != u. Since u occupies at most
	// one of the three stored slots, the remaining slots plus g[fa] always
	// contain the two best admissible branches.
	// NOTE(review): paths lying entirely inside a sibling's subtree (touching no
	// ancestor of u) are not folded into h[u]. This appears intentional: cutting
	// the edge above such a subtree records a pair that dominates it. Confirm
	// this before reusing h[] for anything else.
	void dfs2(int u, int fa, int s) {
		g[u]=s;
		if (fa) {
			int a[5], tot=0;
			if (g[fa]) a[++tot]=g[fa];
			if (fir[fa].sec!=u) a[++tot]=fir[fa].fir+1;
			if (sec[fa].sec!=u) a[++tot]=sec[fa].fir+1;
			if (thr[fa].sec!=u) a[++tot]=thr[fa].fir+1;
			sort(a+1, a+tot+1, [](int a, int b){return a>b;});
			// Two branches both contain fa, so fa is counted once: a[1]+a[2]-1.
			if (tot==1) h[u]=max(h[fa], a[1]);
			else if (tot>1) h[u]=max(h[fa], a[1]+a[2]-1);
		}
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v!=fa) {
				// g[v]: step up to u, then either continue upward (g[u]) or drop
				// into u's best child other than v (that chain + u + v = f+2).
				if (fir[u].sec==v) dfs2(v, u, max(g[u]+1, sec[u].fir+2));
				else dfs2(v, u, max(g[u]+1, fir[u].fir+2));
			}
		}
	}

	// Runs both passes, then records the pair (h[v], k[v]) for every cut edge
	// in both orders, so ans[x] = max y over recorded pairs (x, y). A suffix
	// maximum turns that into "max y over pairs whose first length is >= x";
	// the program prints the sum of ans[1..n] and exits.
	void solve() {
		dep[1]=1; dfs1(1, 0); dfs2(1, 0, 1);
		#if 0
		cout<<"f: "; for (int i=1; i<=n; ++i) cout<<f[i]<<' '; cout<<endl;
		cout<<"g: "; for (int i=1; i<=n; ++i) cout<<g[i]<<' '; cout<<endl;
		cout<<"k: "; for (int i=1; i<=n; ++i) cout<<k[i]<<' '; cout<<endl;
		cout<<"h: "; for (int i=1; i<=n; ++i) cout<<h[i]<<' '; cout<<endl;
		#endif

		// Every non-root vertex v identifies exactly one edge, (parent(v), v).
		// Cutting it separates subtree(v), with diameter k[v], from the rest,
		// which is represented by h[v]. Enumerating vertices (rather than e[])
		// keeps this independent of whether e[i].from is populated and of the
		// order in which each edge's endpoints were given in the input.
		for (int v=2; v<=n; ++v) {
			int t1=h[v], t2=k[v];
			ans[t1]=max(ans[t1], t2);
			ans[t2]=max(ans[t2], t1);
		}
		// ans[n+1] is still 0 here (zero-initialised global, assuming n+1 < N).
		for (int i=n; i; --i) ans[i]=max(ans[i], ans[i+1]);
		// cout<<"ans: "; for (int i=1; i<=n; ++i) cout<<ans[i]<<' '; cout<<endl;
		ll sum=0;
		for (int i=1; i<=n; ++i) sum+=ans[i];
		printf("%lld\n", sum);
		exit(0);
	}
}

signed main()
{
	// File-based I/O: input from tree.in, output to tree.out.
	freopen("tree.in", "r", stdin);
	freopen("tree.out", "w", stdout);

	n = read();
	// Every adjacency chain starts out empty (-1 terminates the edge lists).
	memset(head, -1, sizeof(head));
	// A tree on n vertices has n - 1 edges; store each one in both directions.
	for (int remaining = n - 1; remaining > 0; --remaining) {
		const int a = read();
		const int b = read();
		add(a, b);
		add(b, a);
	}

	task1::solve();
	// force::solve();

	return 0;
}
posted @ 2021-11-11 20:38  Administrator-09  阅读(0)  评论(0)  编辑  收藏  举报