// 写得巨慢…… (author's note, roughly: "this came out really slow")
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

// Given a weighted tree and m vertex pairs ("plans"), one edge may have its
// weight set to 0. The program prints the smallest x such that, after zeroing
// one suitable edge, every plan's path length is <= x.
//
// Input (stdin):  n m, then n-1 edges "x y w", then m plans "u v".
// Output (stdout): that minimum x.
//
// Approach: binary lifting gives each plan's LCA and length (get_dis). The
// answer is binary-searched (half_search). check(x) marks every plan longer
// than x in an edge-difference array. It then finds the heaviest edge used by
// all of them and tests whether zeroing it brings the longest one down to x.

const int maxv = 300500;  // capacity for vertices and plans
const int maxe = 600500;  // capacity for directed edges (two per tree edge)
const int maxlog = 20;    // highest binary-lifting level (2^20 > maxv)

struct edge {
    int v, w, nxt;  // target vertex, weight, next edge out of the same source
} e[maxe];          // 1-based; index 0 is never written, so e[0].w == 0

struct plans {
    int u, v;  // endpoints; get_dis swaps them so that dep[u] >= dep[v]
    int t, d;  // LCA of u and v, and the path length between them
} p[maxv];

int n, m, g[maxv], nume = 0;  // g[x]: index of the first edge leaving x
int dis[maxv][maxlog + 1];    // dis[i][k]: path weight from i up 2^k levels
int anc[maxv][maxlog + 1];    // anc[i][k]: 2^k-th ancestor of i (0 above root)
int dep[maxv];                // depth, root = 0
int fath_e[maxv];             // edge index joining i to its parent (0 for root)
int visit_order[maxv];        // vertices with every parent before its children
int visit_cnt = 0;            // number of entries in visit_order
int d = 0, t = 0;             // results of the last get_dis call: length, LCA
int stack[maxv], top = 0;     // plans longer than the current candidate answer
int val[maxv];                // edge-difference marks, then subtree sums

// Appends the directed edge u -> v with weight w to u's adjacency list.
void addedge(int u, int v, int w) {
    e[++nume].v = v;
    e[nume].w = w;
    e[nume].nxt = g[u];
    g[u] = nume;
}

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it. Returns 0 if input ends before a digit.
int read() {
    int data = 0;
    int ch = std::getchar();  // int, not char, so EOF can be recognised
    while (ch != EOF && (ch < '0' || ch > '9')) ch = std::getchar();
    while (ch >= '0' && ch <= '9') {
        data = data * 10 + (ch - '0');
        ch = std::getchar();
    }
    return data;
}

// Roots the tree at x. For every vertex reachable from x it fills anc[][0],
// dis[][0], dep and fath_e, and records the vertex in visit_order (parents
// before children). It is a breadth-first walk with no recursion, so long
// path-shaped trees cannot overflow the call stack.
void dfs1(int x) {
    int head = 0;
    visit_cnt = 0;
    visit_order[visit_cnt++] = x;
    while (head < visit_cnt) {
        int u = visit_order[head++];
        for (int i = g[u]; i; i = e[i].nxt) {
            int v = e[i].v;
            if (v == anc[u][0]) continue;  // edge back to the parent
            anc[v][0] = u;
            dis[v][0] = e[i].w;
            dep[v] = dep[u] + 1;
            fath_e[v] = i;
            visit_order[visit_cnt++] = v;
        }
    }
}

// Turns the edge-difference marks in val[] into subtree sums. Afterwards
// val[i] is the number of marked plans that use the edge from i to its
// parent. x must be the root previously passed to dfs1, which is
// visit_order[0]. Walking visit_order backwards finishes every child before
// its parent, so no recursion is needed.
void dfs2(int x) {
    (void)x;  // root is implied by visit_order; kept for the original signature
    for (int k = visit_cnt - 1; k > 0; --k) {
        int v = visit_order[k];
        val[anc[v][0]] += val[v];
    }
}

// Computes LCA and path length for plan x, storing them in the globals t and d.
// Also swaps p[x].u / p[x].v if needed so that dep[p[x].u] >= dep[p[x].v].
void get_dis(int x) {
    int u = p[x].u, v = p[x].v;
    d = 0;
    t = 1;
    if (dep[u] < dep[v]) {
        std::swap(u, v);
        std::swap(p[x].u, p[x].v);
    }
    // Lift u to the depth of v.
    for (int k = maxlog; k >= 0; --k) {
        if (anc[u][k] && dep[anc[u][k]] >= dep[v]) {
            d += dis[u][k];
            u = anc[u][k];
        }
    }
    if (u == v) {
        t = u;
        return;
    }
    // Lift both until they sit just below their LCA.
    for (int k = maxlog; k >= 0; --k) {
        if (anc[u][k] != anc[v][k]) {
            d += dis[u][k] + dis[v][k];
            u = anc[u][k];
            v = anc[v][k];
        }
    }
    t = anc[u][0];
    d += dis[u][0] + dis[v][0];
}

// Builds binary-lifting levels 1..maxlog from level 0.
// Vertex 0 stands for "above the root" and has all-zero entries.
void get_table1() {
    for (int k = 1; k <= maxlog; ++k)
        for (int i = 1; i <= n; ++i) {
            int mid = anc[i][k - 1];
            anc[i][k] = anc[mid][k - 1];
            dis[i][k] = dis[i][k - 1] + dis[mid][k - 1];
        }
}

// Stores the LCA (t) and path length (d) of every plan.
void get_table2() {
    for (int i = 1; i <= m; ++i) {
        get_dis(i);
        p[i].t = t;
        p[i].d = d;
    }
}

// Clears val[] and writes edge-difference marks for the plans in stack[1..top].
void get_labled() {
    std::memset(val, 0, sizeof(val));
    for (int i = 1; i <= top; ++i) {
        const plans &q = p[stack[i]];
        if (q.t == q.v) {  // v is an ancestor of u: one vertical path
            val[q.v]--;
            val[q.u]++;
        } else {
            val[q.t] -= 2;
            val[q.u]++;
            val[q.v]++;
        }
    }
}

// Returns true if zeroing a single edge can make every plan's length <= x.
// Only plans longer than x matter. The best edge to zero is the heaviest one
// that all of them share, i.e. the edges with val[i] == top after dfs2.
bool check(int x) {
    top = 0;
    int longest = 0;
    for (int i = 1; i <= m; ++i) {
        if (p[i].d > x) {
            stack[++top] = i;
            longest = std::max(longest, p[i].d);
        }
    }
    if (top == 0) return true;  // nothing needs shortening
    get_labled();
    dfs2(1);
    int mx = 0;
    for (int i = 1; i <= n; ++i)
        if (val[i] == top) mx = std::max(mx, e[fath_e[i]].w);
    return longest - mx <= x;
}

// Binary-searches the smallest x in [0, longest plan length] with check(x).
// The search includes 0, which is the answer when one edge covers every plan.
// check(longest) always holds, so ans is always set.
int half_search() {
    int l = 0, r = 0;
    for (int i = 1; i <= m; ++i) r = std::max(r, p[i].d);
    int ans = r;
    while (l <= r) {
        int mid = l + (r - l) / 2;
        if (check(mid)) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    return ans;
}

int main() {
    n = read();
    m = read();
    for (int i = 1; i < n; ++i) {
        int a = read();
        int b = read();
        int w = read();
        addedge(a, b, w);
        addedge(b, a, w);
    }
    for (int i = 1; i <= m; ++i) {
        p[i].u = read();
        p[i].v = read();
    }
    dfs1(1);
    get_table1();
    get_table2();
    std::printf("%d\n", half_search());
    return 0;
}