树上差分
树上差分的实质是：先以 O(1) 的代价在节点上打差分标记，再通过一次 O(n) 的 DFS 自底向上把子树中的标记累加（回溯）到每个节点，从而得到每个节点（或其父边）被覆盖的次数。它把序列上的"差分 + 前缀和"思想推广到了树上的路径（树链），是一种非常优秀的技巧。
例题：NOIP 2015 提高组「运输计划」（二分答案 + 树上边差分）
// NOIP 2015 "transport plan": binary search on the answer + tree (edge) difference.
//
// Input: a weighted tree with n nodes followed by m paths (u, v).
// check(t) decides whether the weight of a single edge e can be reduced to 0 so
// that every path becomes <= t:
//   - all paths longer than t must pass through e, and
//   - (longest path) - w(e) must be <= t.
// The smallest feasible t is printed.
//
// All traversals are iterative. A recursive DFS can go n levels deep (n up to
// about 3e5 on a path-shaped tree), which may overflow a default-sized call stack.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;

typedef long long ll;
typedef pair<int, int> pii;
const int MAXN = 3e5 + 10;

// Reads one (possibly negative) integer from stdin.
// Stops at EOF instead of spinning forever on truncated input.
ll read() {
    ll x = 0, sign = 1;
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return sign * x;
}

vector<pii> adj[MAXN];  // adj[v] = {neighbour, edge weight}
int n, m;
int fa[MAXN];        // parent in the rooted tree (0 for the root)
int dep[MAXN];       // depth, root has depth 1
int sz[MAXN];        // subtree size
int heavy[MAXN];     // child with the largest subtree, 0 if none (sz[0] stays 0)
int top[MAXN];       // head of the heavy chain containing v
int upW[MAXN];       // weight of edge (fa[v], v); 0 for the root
int bfsOrder[MAXN];  // BFS order: every parent appears before its children
ll dis[MAXN];        // distance from the root
int cover[MAXN];     // difference marks; after summing: #over-long paths using edge (fa[v], v)

// Roots the tree at `root` and fills fa, dep, dis, upW, sz, heavy, top and
// bfsOrder without recursion.
void buildTree(int root) {
    int head = 0, tail = 0;
    fa[root] = 0;
    dep[root] = 1;
    dis[root] = 0;
    upW[root] = 0;
    bfsOrder[tail++] = root;
    while (head < tail) {
        int v = bfsOrder[head++];
        for (size_t j = 0; j < adj[v].size(); ++j) {
            int to = adj[v][j].first, w = adj[v][j].second;
            if (to == fa[v]) continue;
            fa[to] = v;
            dep[to] = dep[v] + 1;
            upW[to] = w;
            dis[to] = dis[v] + w;
            bfsOrder[tail++] = to;
        }
    }
    // Bottom-up pass. Children appear after their parent in bfsOrder, so walking
    // it backwards finalises each subtree size before it is added to the parent.
    for (int i = 0; i < tail; ++i) sz[bfsOrder[i]] = 1;
    for (int i = tail - 1; i > 0; --i) {
        int v = bfsOrder[i], p = fa[v];
        sz[p] += sz[v];
        if (sz[v] > sz[heavy[p]]) heavy[p] = v;
    }
    // Top-down pass: a heavy child continues its parent's chain; others start one.
    top[root] = root;
    for (int i = 1; i < tail; ++i) {
        int v = bfsOrder[i];
        top[v] = (heavy[fa[v]] == v) ? top[fa[v]] : v;
    }
}

// Lowest common ancestor of u and v via heavy-light decomposition.
int getLca(int u, int v) {
    while (top[u] != top[v]) {
        // Lift the endpoint whose chain head is deeper.
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        u = fa[top[u]];
    }
    return dep[u] < dep[v] ? u : v;
}

struct Plan {
    int u, v, lca;
    ll len;  // weighted length of the path u -> v
};
Plan plan[MAXN];

// Returns true if zeroing the weight of one edge can make every path <= t.
bool check(ll t) {
    int over = 0;    // number of paths longer than t
    ll longest = 0;  // length of the longest such path
    for (int i = 1; i <= m; ++i) {
        if (plan[i].len <= t) continue;
        ++over;
        longest = max(longest, plan[i].len);
        // Edge difference: +1 at both endpoints, -2 at the LCA.
        ++cover[plan[i].u];
        ++cover[plan[i].v];
        cover[plan[i].lca] -= 2;
    }
    if (over == 0) return true;
    // Accumulate marks bottom-up (reverse BFS order) instead of a recursive DFS.
    for (int i = n - 1; i > 0; --i) {
        int v = bfsOrder[i];
        cover[fa[v]] += cover[v];
    }
    int best = 0;  // heaviest edge shared by all over-long paths
    for (int v = 1; v <= n; ++v)
        if (cover[v] == over) best = max(best, upW[v]);
    fill(cover, cover + n + 1, 0);  // reset for the next binary-search step
    return longest - best <= t;
}

int main() {
    n = (int)read();
    m = (int)read();
    for (int i = 1; i < n; ++i) {
        int u = (int)read(), v = (int)read(), w = (int)read();
        adj[u].push_back(make_pair(v, w));
        adj[v].push_back(make_pair(u, w));
    }
    buildTree(1);

    ll hi = 0;  // the longest path is always a feasible answer
    for (int i = 1; i <= m; ++i) {
        plan[i].u = (int)read();
        plan[i].v = (int)read();
        plan[i].lca = getLca(plan[i].u, plan[i].v);
        plan[i].len = dis[plan[i].u] + dis[plan[i].v] - 2 * dis[plan[i].lca];
        hi = max(hi, plan[i].len);
    }

    // check() is monotone in t: find the smallest feasible t.
    ll lo = 0, ans = 0;
    while (lo <= hi) {
        ll mid = (lo + hi) >> 1;
        if (check(mid)) {
            ans = mid;
            hi = mid - 1;
        } else {
            lo = mid + 1;
        }
    }
    printf("%lld\n", ans);
    return 0;
}