基本思路:
- 题目要求树直径的必经边,那么首先应当获取一条直径.
- 获取直径后,分别从直径的两个端点出发沿直径各遍历一次;遍历时对直径上的每个点做一次不经过直径上其他点的dfs,求出该点支链的最大深度.如果这个深度等于该点到出发端点的距离,说明这一段直径可以被支链替换,这一段上的边就不是必经边.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
#define ll long long
using namespace std;

// Array capacities: vertex-indexed arrays hold MAXN entries; every undirected
// edge is stored as two directed arcs, hence MAXM = 2 * MAXN.
const int MAXN=2e5+10,MAXM=MAXN*2;

// One directed arc of the linked adjacency list.
struct Edge{
    ll from,to,w,nxt; // tail vertex, head vertex, weight, index of the next arc leaving `from` (0 = end of list)
}e[MAXM];
int head[MAXN],edgeCnt=1; // head[u] = most recently added arc leaving u; arcs are numbered from 2 so that 0 can terminate a list

// Appends the directed arc u -> v with weight w to u's adjacency list.
// Callers add both directions to model an undirected edge.
void addEdge(ll u,ll v,ll w){
    ++edgeCnt;
    e[edgeCnt].from=u;
    e[edgeCnt].to=v;
    e[edgeCnt].w=w;
    e[edgeCnt].nxt=head[u];
    head[u]=edgeCnt;
}

ll dep[MAXN];            // distance from the source of the latest bfs() call
bool vis[MAXN];          // visited marks of the latest bfs() call
int n;                   // number of vertices (vertices are numbered 1..n)
int from[MAXN];          // BFS-tree predecessor from the latest bfs() call (0 = none)
bool inDiameter[MAXN];   // whether a vertex lies on the chosen diameter
ll st,ed;                // endpoints of the chosen diameter

// Marks every vertex on the chosen diameter by following from[] from `ed`
// until a vertex without predecessor is passed. Expects from[] to be the
// result of bfs(st), so that the chain ends at `st`.
void markDiameter(){
    for(int cur=ed;cur;cur=from[cur]){
        inDiameter[cur]=1;
    }
}

// Traverses the graph from `s`, filling dep[] (weighted distance from s),
// vis[] and from[] (predecessor on the path from s); all three arrays are
// cleared first.
// Returns the lowest-numbered vertex in 1..n whose distance is strictly the
// largest, or 0 when no vertex has a positive distance.
ll bfs(int s){
    memset(dep,0,sizeof(dep));
    memset(vis,0,sizeof(vis));
    memset(from,0,sizeof(from));
    queue<int> q;
    q.push(s);
    vis[s]=1;
    while(!q.empty()){
        int nowU=q.front();
        q.pop();
        for(int i=head[nowU];i;i=e[i].nxt){
            int nowV=e[i].to;
            if(vis[nowV])continue;
            dep[nowV]=dep[nowU]+e[i].w;
            vis[nowV]=1;
            from[nowV]=nowU;
            q.push(nowV);
        }
    }
    ll ans=0,ansDep=0;
    for(int i=1;i<=n;i++){
        if(dep[i]>ansDep){
            ans=i;
            ansDep=dep[i];
        }
    }
    return ans;
}

// Shared worker of the two distance passes below: traverses from `src` and
// writes dist[v] = dist[parent] + edge weight for every newly reached vertex.
// Neither array is cleared here, so both must still be in their
// zero-initialised state when this is called.
static void bfsDistances(int src,ll dist[],bool seen[]){
    queue<int> q;
    q.push(src);
    seen[src]=1;
    while(!q.empty()){
        int nowU=q.front();
        q.pop();
        for(int i=head[nowU];i;i=e[i].nxt){
            int nowV=e[i].to;
            if(seen[nowV])continue;
            seen[nowV]=1;
            dist[nowV]=dist[nowU]+e[i].w;
            q.push(nowV);
        }
    }
}

ll dis_fromST[MAXN];          // distance of every vertex to the diameter endpoint st
bool vis_getDisFromST[MAXN];
// Fills dis_fromST[] with the distance of every reachable vertex to `st`.
void bfs_getDisFromST(){
    bfsDistances(st,dis_fromST,vis_getDisFromST);
}

ll dis_fromED[MAXN];          // distance of every vertex to the diameter endpoint ed
bool vis_getDisFromED[MAXN];
// Fills dis_fromED[] with the distance of every reachable vertex to `ed`.
void bfs_getDisFromED(){
    bfsDistances(ed,dis_fromED,vis_getDisFromED);
}

ll dep_noDiameter[MAXN]; // depth measured from a diameter vertex without stepping onto other diameter vertices

// Explores everything reachable from `x` without entering `fa` or any vertex
// flagged in inDiameter[], writing dep_noDiameter[v] for each vertex reached
// (the caller provides dep_noDiameter[x] as the starting depth).
// Returns the largest dep_noDiameter value seen, including that of `x`.
// An explicit stack replaces recursion so that a very long side chain cannot
// exhaust the call stack.
ll dfs_noDiameter(int x,int fa){
    ll maxDep=dep_noDiameter[x];
    vector<pair<int,int> > pending; // (vertex, vertex it was reached from)
    pending.push_back(make_pair(x,fa));
    while(!pending.empty()){
        int nowU=pending.back().first;
        int parent=pending.back().second;
        pending.pop_back();
        for(int i=head[nowU];i;i=e[i].nxt){
            int nowV=e[i].to;
            if(nowV==parent||inDiameter[nowV])continue;
            dep_noDiameter[nowV]=dep_noDiameter[nowU]+e[i].w;
            maxDep=max(maxDep,dep_noDiameter[nowV]);
            pending.push_back(make_pair(nowV,nowU));
        }
    }
    return maxDep;
}
int main(){ scanf("%d",&n); for(int i=1;i<=n-1;i++){ ll a,b,c; cin>>a>>b>>c; addEdge(a,b,c); addEdge(b,a,c); } st=bfs(1); ed=bfs(st);//直径 cout<<dep[ed]<<endl; markDiameter();//标记直径上点 bfs_getDisFromST(); bfs_getDisFromED();//获取每个点到两个端点的距离 int l=st,r=ed;//必经边端点 int tmp=from[ed]; while(tmp){//r if(tmp==st)break; dep_noDiameter[tmp]=0; ll nowMaxDep=dfs_noDiameter(tmp,0); if(nowMaxDep==dis_fromED[tmp]){ r=tmp; } tmp=from[tmp]; } bfs(ed);//重新获取from数组 tmp=from[st]; while(tmp){//l if(tmp==ed||tmp==r)break; ll nowMaxDep=dfs_noDiameter(tmp,0); if(nowMaxDep==dis_fromST[tmp])l=tmp; tmp=from[tmp]; } int cnt=0; tmp=l; while(tmp){ if(tmp==r)break; cnt++; tmp=from[tmp]; } printf("%d\n",cnt); return 0; }