J. Distance on the tree(树链剖分+线段树)
题:https://nanti.jisuanke.com/t/38229
题意:给定n个点m个询问。每个询问[x,,y,w]问x到y的路径上边权小于等于w的边数
离线处理,
运用树链剖分让LCA跑快点
关键是把n-1条边,和m条询问边存起来
然后按边权值W进行升序;
这样在计数询问的时候我们从小到大计数;
每条边只会被记一次且从小到大,这样就不用担心当前计数会受上一计数更新时的影响;
每次把小于等于当前查询的边加到链上;
查询就是查询链上有多少条边被加入过;
#include<bits/stdc++.h> using namespace std; const int M=1e5+10; inline int read(){ int sum=0,x=1; char ch=getchar(); while(ch<'0'||ch>'9'){ if(ch=='-') x=0; ch=getchar(); } while(ch>='0'&&ch<='9') sum=(sum<<1)+(sum<<3)+(ch^48),ch=getchar(); return x?sum:-sum; } int f[M],sz[M],deep[M],son[M],dfn[M],top[M],ans[M],t[M<<2],n,cnt; vector<int >graph[M]; struct node{ int u,v,w,index; bool operator<(const node &b)const{ return w<b.w; } }q[M],e[M]; void dfs1(int u,int from){ f[u]=from; sz[u]=1; deep[u]=deep[from]+1; for(int i=0;i<graph[u].size();i++){ int v=graph[u][i]; if(v!=from){ dfs1(v,u); sz[u]+=sz[v]; if(sz[v]>sz[son[u]]) son[u]=v; } } } void dfs2(int u,int t){ top[u]=t; dfn[u]=++cnt; if(!son[u]) return ; dfs2(son[u],t); for(int i=0;i<graph[u].size();i++){ int v=graph[u][i]; if(v!=son[u]&&v!=f[u]){ dfs2(v,v); } } } void update(int sign,int c,int root,int l,int r){ if(l==r){ t[root]+=c; return ; } int midd=l+r>>1; if(sign<=midd) update(sign,c,root<<1,l,midd); else update(sign,c,root<<1|1,midd+1,r); t[root]=t[root<<1]+t[root<<1|1]; } int find(int L,int R,int root,int l,int r){ if(L<=l&&r<=R) return t[root]; int midd=l+r>>1; int c=0; if(L<=midd) c+=find(L,R,root<<1,l,midd); if(R>midd) c+=find(L,R,root<<1|1,midd+1,r); return c; } int solve(int u,int v){ int c=0; int fu=top[u],fv=top[v]; while(fu!=fv){ if(deep[fu]>=deep[fv]){ c+=find(dfn[fu],dfn[u],1,1,n); u=f[fu],fu=top[u]; } else{ c+=find(dfn[fv],dfn[v],1,1,n); v=f[fv],fv=top[v]; } } if(dfn[u]<dfn[v]) c+=find(dfn[u]+1,dfn[v],1,1,n); else if(dfn[u]>dfn[v]) c+=find(dfn[v]+1,dfn[u],1,1,n); return c; } int main(){ n=read(); int m=read(); for(int i=1;i<n;i++){ int x=read(),y=read(),w=read(); e[i].u=x,e[i].v=y,e[i].w=w; graph[x].push_back(y); graph[y].push_back(x); } dfs1(1,1); dfs2(1,1); /*cout<<"~~~~~~~~~"; for(int i=1;i<=n;i++) cout<<dfn[i]<<" "; cout<<endl;*/ for(int i=1;i<=m;i++){ int x=read(),y=read(),w=read(); q[i].u=x,q[i].v=y,q[i].w=w; q[i].index=i; } sort(e+1,e+n); sort(q+1,q+1+m); int cur=1; for(int i=1;i<=m;i++){ while(cur<n&&e[cur].w<=q[i].w){ int u=e[cur].u,v=e[cur].v; if(deep[u]<deep[v]) swap(u,v); update(dfn[u],1,1,1,n); cur++; } ans[q[i].index]+=solve(q[i].u,q[i].v); } for(int i=1;i<=m;i++) printf("%d\n",ans[i]); return 0; }