题解 树上路径
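有一个暴力做法：枚举断开哪一条边，对形成的两个连通块分别求直径，用这一对直径更新答案。
断开边 \((fa, v)\) 后，两个连通块恰好是 \(v\) 的子树和子树之外的部分。因此先自底向上做一遍树形 DP 求出每个点子树内的直径，再自顶向下（换根）求出每个点子树之外部分的直径，这样每条边对应的两个直径都能 \(O(1)\) 得到。
整体复杂度 \(O(n)\)。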
Code:
#include <bits/stdc++.h>
using namespace std;
// Sentinel "infinity" value; not referenced anywhere in the code shown.
#define INF 0x3f3f3f3f
// Capacity of the per-vertex arrays (the edge array e holds N<<1 entries).
// Indices up to n+1 are touched (ans[n+1] in task1), so n must stay below N - 1.
#define N 500010
#define ll long long
// Shorthands: fir/sec for pair members, make for make_pair. Being macros, they
// rewrite every later occurrence of the identifiers fir, sec and make.
#define fir first
#define sec second
#define make make_pair
// Left disabled: would widen every `int` (hence `signed main`) to 64 bits.
//#define int long long
// Fast input: stdin is pulled in 2 MiB chunks into buf; [p1, p2) is the unread part.
char buf[1<<21], *p1=buf, *p2=buf;
// Next input byte as a value in [0, 255], or EOF once stdin is exhausted. Reading the byte
// through unsigned char keeps a 0xFF byte distinct from EOF and makes every value safe to
// pass to isdigit().
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
// Reads one decimal integer. Every non-digit byte before it is skipped, and each '-' among
// those bytes flips the sign. Returns 0 when the input ends before any digit instead of
// spinning forever. No overflow check: the value must fit in int.
inline int read() {
    int ans=0, f=1;
    int c=getchar();  // int, not char: must be able to hold EOF as well as 0..255
    while (c!=EOF && !isdigit(c)) {
        if (c=='-') f=-f;
        c=getchar();
    }
    while (isdigit(c)) {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
// n: number of vertices (1-indexed).
// head[u]: index in e of u's most recently added adjacency entry, or -1 when there is none
//          (main() fills head with -1 before adding edges).
// size: number of entries stored so far; entries live at e[1..size].
// NOTE(review): under `using namespace std` in C++17 the global name `size` may be ambiguous
// with std::size; renaming it needs a coordinated change across the file.
int n;
int head[N], size;
// One directed adjacency entry from -> to; next is the previous entry out of the same vertex.
struct edge{int from, to, next;}e[N<<1];
// Appends the directed entry s -> t to s's adjacency list, filling every field (including
// `from`, which was previously left at 0) so that e[i] fully describes the edge.
inline void add(int s, int t) {
    e[++size].from=s;
    e[size].to=t;
    e[size].next=head[s];
    head[s]=size;
}
// Reference brute force, compiled out by #if 0 (main() keeps only a commented-out
// force::solve() call). It counts pairs (x, y) with 1 <= x, y <= n for which some path of
// x edges and some path of y edges share no vertex. It loops over every (x, y, i, j, k, l),
// so it is only practical on very small trees, e.g. to cross-check task1.
// NOTE(review): lengths here are edge counts (dis), while task1 counts vertices (f = 1 for a
// single vertex); confirm which convention the problem uses before comparing outputs.
#if 0
namespace force{
// dep: depth with dep[1] = 1. fa[i][u]: the 2^i-th ancestor of u (0 above the root).
// lg[d] = floor(log2 d) + 1, built in solve(). vis: vertices on the currently painted path.
int dep[N], fa[24][N], lg[N];
bool vis[N];
// DFS from u (parent pa). First fills fa[i][u] for i >= 1 from ancestor rows that were
// built earlier in the DFS, then sets dep and fa[0] of each child before descending.
void dfs1(int u, int pa) {
for (int i=1; i<=20; ++i)
if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
else break;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v!=pa) {
dep[v]=dep[u]+1; fa[0][v]=u;
dfs1(v, u);
}
}
}
// Binary-lifting LCA: lift the deeper vertex by the largest power of two not exceeding
// the depth gap until depths match, then lift both while their ancestors differ.
int lca(int a, int b) {
if (dep[a]<dep[b]) swap(a, b);
while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
if (a==b) return a;
for (int i=lg[dep[a]]-1; ~i; --i)
if (fa[i][a]!=fa[i][b])
a=fa[i][a], b=fa[i][b];
return fa[0][a];
}
// Plain DFS from u (parent fa) searching for `to`. On success it marks every vertex of
// the u -> to path in vis and returns 1; otherwise it returns 0.
int paint(int u, int fa, int to) {
if (u==to) {vis[u]=1; return 1;}
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v!=fa) {
int t=paint(v, u, to);
if (t) {vis[u]=1; return 1;}
}
}
return 0;
}
// Same search as paint(), but ORs vis[] of every vertex on the u -> to path into tag
// instead of marking. Returns 1 iff `to` was reached.
int anycol(int u, int fa, int to, bool& tag) {
if (u==to) {tag|=vis[u]; return 1;}
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v!=fa) {
int t=anycol(v, u, to, tag);
if (t) {tag|=vis[u]; return 1;}
}
}
return 0;
}
// Number of edges on the a -> b path.
int dis(int a, int b) {return dep[a]+dep[b]-2*dep[lca(a, b)];}
// Returns 1 if some x-edge path (i, j) and some y-edge path (k, l) share no vertex.
// Writes debug lines to cout.
bool check(int x, int y) {
cout<<"check: "<<x<<' '<<y<<endl;
for (int i=1; i<=n; ++i) {
for (int j=1; j<=n; ++j) {
if (dis(i, j)!=x) continue;
// Mark the vertices of path (i, j), then look for a y-path avoiding all of them.
memset(vis, 0, sizeof(vis));
paint(i, 0, j);
for (int k=1; k<=n; ++k) {
for (int l=1; l<=n; ++l) {
if (dis(k, l)!=y) continue;
bool tag=0;
anycol(k, 0, l, tag);
if (tag) continue;
return 1;
}
}
}
}
cout<<"return 0"<<endl;
return 0;
}
// Builds lg, roots the tree at 1, prints how many (x, y) pass check(), then exits.
void solve() {
for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
dep[1]=1; dfs1(1, 0);
int ans=0;
for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) if (check(i, j)) ++ans;
cout<<ans<<endl;
exit(0);
}
}
#endif
namespace task1{
int f[N], g[N], k[N], h[N], dep[N], ans[N];
pair<int, int> fir[N], sec[N], thr[N];
void dfs1(int u, int fa) {
f[u]=k[u]=1;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v!=fa) {
dep[v]=dep[u]+1;
dfs1(v, u);
f[u]=max(f[u], f[v]+1);
k[u]=max(k[u], f[v]+1);
k[u]=max(k[u], k[v]);
if (f[v]>=fir[u].fir) thr[u]=sec[u], sec[u]=fir[u], fir[u]=make(f[v], v);
else if (f[v]>=sec[u].fir) thr[u]=sec[u], sec[u]=make(f[v], v);
else if (f[v]>thr[u].fir) thr[u]=make(f[v], v);
}
}
if (fir[u].fir&&sec[u].fir) k[u]=max(k[u], fir[u].fir+sec[u].fir+1);
}
void dfs2(int u, int fa, int s) {
g[u]=s;
if (fa) {
int a[5], tot=0;
if (g[fa]) a[++tot]=g[fa];
if (fir[fa].sec!=u) a[++tot]=fir[fa].fir+1;
if (sec[fa].sec!=u) a[++tot]=sec[fa].fir+1;
if (thr[fa].sec!=u) a[++tot]=thr[fa].fir+1;
sort(a+1, a+tot+1, [](int a, int b){return a>b;});
if (tot==1) h[u]=max(h[fa], a[1]);
else if (tot>1) h[u]=max(h[fa], a[1]+a[2]-1);
}
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v!=fa) {
if (fir[u].sec==v) dfs2(v, u, max(g[u]+1, sec[u].fir+2));
else dfs2(v, u, max(g[u]+1, fir[u].fir+2));
}
}
}
void solve() {
dep[1]=1; dfs1(1, 0); dfs2(1, 0, 1);
#if 0
cout<<"f: "; for (int i=1; i<=n; ++i) cout<<f[i]<<' '; cout<<endl;
cout<<"g: "; for (int i=1; i<=n; ++i) cout<<g[i]<<' '; cout<<endl;
cout<<"k: "; for (int i=1; i<=n; ++i) cout<<k[i]<<' '; cout<<endl;
cout<<"h: "; for (int i=1; i<=n; ++i) cout<<h[i]<<' '; cout<<endl;
#endif
for (int i=1,u,v; i<=size; i+=2) {
u=e[i].from; v=e[i].to;
if (dep[u]>dep[v]) swap(u, v);
int t1=h[v], t2=k[v];
ans[t1]=max(ans[t1], t2);
ans[t2]=max(ans[t2], t1);
}
for (int i=n; i; --i) ans[i]=max(ans[i], ans[i+1]);
// cout<<"ans: "; for (int i=1; i<=n; ++i) cout<<ans[i]<<' '; cout<<endl;
ll sum=0;
for (int i=1; i<=n; ++i) sum+=ans[i];
printf("%lld\n", sum);
exit(0);
}
}
// Reads the tree from tree.in: first n, then n-1 endpoint pairs "u v". Stores every edge
// in both directions and lets task1::solve() print the answer, which goes to tree.out.
signed main()
{
    freopen("tree.in", "r", stdin);
    freopen("tree.out", "w", stdout);
    n=read();
    memset(head, -1, sizeof(head));  // -1 terminates every adjacency list
    for (int remaining=n-1; remaining>0; --remaining) {
        const int a=read();
        const int b=read();
        add(a, b);
        add(b, a);
    }
    task1::solve();
    // Brute-force cross-check (needs its #if 0 block enabled):
    // force::solve();
    return 0;
}