【SDOI2013】JZOJ8月3日提高组T4 直径
题目
题目描述
小 Q 最近学习了一些图论知识。根据课本,有如下定义。
树:无回路且连通的无向图,每条边都有正整数的权值来表示其长度。如果一棵树有 N 个节点,可以证明其有且仅有 N-1 条边。
路径:一棵树上,任意两个节点之间最多有一条简单路径。我们用 dis(a,b)表示点 a 和点 b 的路径上各边长度之和。称 dis(a,b)为 a、b 两个节点间的距离。
直径:一棵树上,最长的路径为树的直径。树的直径可能不是唯一的。
现在小 Q 想知道,对于给定的一棵树,其直径的长度是多少,以及有多少条边满足所有的直径都经过该边。
数据范围
对于 20%的测试数据:N≤100
对于 40%的测试数据:N≤1000
对于 70%的测试数据:N≤100000
对于 100%的测试数据:2≤N≤200000,所有点的编号都在 1..N 的范围内,边的权值≤10^9。
对于每个测试点,若输出文件的第一行与标准输出相同,则得到该测试点20%的分数,若输出文件的第二行与标准输出相同,则得到该测试点 80%的分数,两项可累加。
本题使用自定义校验器,为防止自定义校验器出错,即使你无法正确得出某一问的答案,也应在相应的位置随便输出一个数字。
题解
题意
给出一个树,求直径和有多少条边在所有的直径上
分析
对于直径,可以先随便找一个点(例如1),然后找到离它最远的那个点(x),然后再去找离x最远的那个点(y),那么x和y之间的距离就是直径
具体证明:传送门
然后把直径拉出来
对于直径上的每个点,记录从x到它的长度(设为len),并求出以它为根的子树距离它最远的点到它的长度(设为dis),并记录有多少个(设为num)
当len=dis的时候,答案直接减去deep(从0开始)后输出
若不是,当num>1,答案变为deep
Code
#include<cstdio> #include<cstring> using namespace std; struct node { long long head,to,next,val; }a[400005]; long long n,i,x,y,z,root,s,ans,tot,deep[200005],f[200005],dis[200005],far[200005]; bool end,b[200005],bb[200005]; void add(long long x,long long y,long long z) { tot++; a[tot].to=y; a[tot].val=z; a[tot].next=a[x].head; a[x].head=tot; } void get(long long now,long long fa) { long long i,x; for (i=a[now].head;i;i=a[i].next) { x=a[i].to; if (x!=fa) { dis[x]=dis[now]+a[i].val; get(x,now); } } } void dfs1(long long now) { long long i; for (i=a[now].head;i;i=a[i].next) { long long x=a[i].to; if (x!=f[now]) { f[x]=now; deep[x]=deep[now]+1; dis[x]=dis[now]+a[i].val; dfs1(x); } } } void dfs2(long long now) { if (end==true) return; long long i,x,num; long long mx; bool bz; num=0; mx=-1; bz=false; for (i=a[now].head;i;i=a[i].next) { x=a[i].to; if (x==f[now]) continue; dfs2(x); if (far[x]>=mx) { if (far[x]>mx) { mx=far[x]; num=1; } else num++; } if (far[x]-dis[now]==dis[now]) bz=true; } if (end==true) return; if (b[now]==true) { if (bz==true) { ans-=deep[now]-1; if (ans<0) ans=0; printf("%lld\n",ans); end=true; return; } else { if (num>1) ans=deep[now]-1; } } if (mx<dis[now]) mx=dis[now]; far[now]=mx; } int main() { scanf("%lld",&n); for (i=1;i<n;i++) { scanf("%lld%lld%lld",&x,&y,&z); add(x,y,z); add(y,x,z); } get(1,0); root=1; for (i=2;i<=n;i++) if (dis[i]>dis[root]) root=i; memset(deep,0,sizeof(deep)); memset(f,0,sizeof(f)); memset(dis,0,sizeof(dis)); deep[root]=1; f[root]=0; bb[root]=true; dfs1(root); for (i=1;i<=n;i++) if (dis[i]>dis[s]) s=i; printf("%lld\n",dis[s]); ans=deep[s]-1; x=s; while (x!=root) { b[x]=true; x=f[x]; } dfs2(root); if (end==false) printf("%lld\n",ans); return 0; }