P1600 天天爱跑步
考虑一个玩家的路径 $(x,y)$ 对路径上的一个节点 $u$ 的贡献
设 $lca=LCA(x,y)$ ,当 $u$ 在链 $x,lca$ 上时,路径会产生 $1$ 的贡献当且仅当 $dep[x]-dep[u]=w[u]$
其中 $dep[i]$ 表示节点 $i$ 的深度,$w[i]$ 就是题目给出的 $W$,上式即 $dep[x]=dep[u]+w[u]$
当 $u$ 在链 $lca,y$ 上时,路径会产生贡献当且仅当 $dep[y]-dep[u]=dis(x,y)-w[u]$
其中 $dis(x,y)$ 表示节点 $x,y$ 的路径长度,上式即 $dep[y]-dep[u]=dep[x]+dep[y]-2dep[lca]-w[u]$
即 $-dep[x]+2dep[lca]=dep[u]-w[u]$
直接树链剖分并对每种深度维护动态开点线段树,分别维护上面两种情况,对于一条路径 $(x,y)$ 直接把对应深度的线段树的, $x$ 到 $y$ 的所有节点$+1$
即深度为 $dep[x]$ 和深度为 $-dep[x]+2dep[lca]$ ,注意这是两种情况,对每种深度都要两颗线段树分别维护
因为有第二种情况的深度可能有负数,所以要把深度加 $N$ 转成正的,变成 $2dep[lca]-dep[x]+N$,并注意这样搞 $lca$ 会被算两次,要减一次
查询节点 $u$ 时就查询深度为 $dep[u]+w[u]$ 和 $dep[u]-w[u]+N$ (注意$+N$)的线段树上节点 $u$ 的值就好了
线段树的最大节点数要注意算好,代码中线段树用标记永久化,又好写速度又快
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> #include<map> using namespace std; inline int read() { int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f; } const int N=3e5+7; int n,m,w[N]; int fir[N],from[N<<1],to[N<<1],cntt; inline void add(int a,int b) { from[++cntt]=fir[a]; fir[a]=cntt; to[cntt]=b; } int rt1[N<<1],rt2[N<<1],T[N*40],L[N*40],R[N*40],Tag[N*40],cnt; //rt1的线段树维护u在(x,lca)的情况,rt2的线段树维护u在(lca,y)的情况 void ins(int &o,int l,int r,int ql,int qr,int K) { if(!o) o=++cnt; if(l>=ql&&r<=qr) { Tag[o]+=K; return; } int mid=l+r>>1; if(ql<=mid) ins(L[o],l,mid,ql,qr,K); if(mid<qr) ins(R[o],mid+1,r,ql,qr,K); } int query(int &o,int l,int r,int pos) { if(!o) return 0; if(l==r) return Tag[o]; int mid=l+r>>1; return ( pos<=mid ? query(L[o],l,mid,pos) : query(R[o],mid+1,r,pos) ) + Tag[o]; } int son[N],Fa[N],sz[N],dep[N],Top[N],id[N],dfs_clock; void dfs1(int x) { sz[x]=1; for(int i=fir[x];i;i=from[i]) { int &v=to[i]; if(v==Fa[x]) continue; Fa[v]=x; dep[v]=dep[x]+1; dfs1(v); sz[x]+=sz[v]; if(sz[v]>sz[son[x]]) son[x]=v; } } void dfs2(int x,int tp) { id[x]=++dfs_clock; Top[x]=tp; if(son[x]) dfs2(son[x],tp); for(int i=fir[x];i;i=from[i]) { int &v=to[i]; if(v==Fa[x]||v==son[x]) continue; dfs2(v,v); } } int LCA(int x,int y) { for(;Top[x]!=Top[y];x=Fa[Top[x]]) if(dep[Top[x]]<dep[Top[y]]) swap(x,y); return dep[x]<dep[y] ? x : y; } //dep[x]-dep[u]==w[u] dep[y]-dep[u]==dep[x]+dep[y]-2dep[lca]-w[u] //dep[x]==dep[u]+w[u] -dep[x]+2dep[lca]==dep[u]-w[u] void work(int x,int y) { int lca=LCA(x,y),p; for(p=x;Top[p]!=Top[lca];p=Fa[Top[p]]) ins(rt1[dep[x]],1,n,id[Top[p]],id[p],1); ins(rt1[dep[x]],1,n,id[lca],id[p],1); for(p=y;Top[p]!=Top[lca];p=Fa[Top[p]]) ins(rt2[2*dep[lca]-dep[x]+N],1,n,id[Top[p]],id[p],1); ins(rt2[2*dep[lca]-dep[x]+N],1,n,id[lca],id[p],1); ins(rt1[dep[x]],1,n,id[lca],id[lca],-1);//注意lca被算了两次 } int main() { n=read(),m=read(); int a,b; for(int i=1;i<n;i++) { a=read(),b=read(); add(a,b); add(b,a); } for(int i=1;i<=n;i++) w[i]=read(); dep[1]=1; dfs1(1); dfs2(1,1); for(int i=1;i<=m;i++) a=read(),b=read(),work(a,b); for(int i=1;i<=n;i++) { int ans1=query(rt1[dep[i]+w[i]],1,n,id[i]); int ans2=query(rt2[dep[i]-w[i]+N],1,n,id[i]); printf("%d ",ans1+ans2); } return 0; }