P1600 天天爱跑步(线段树合并+树上差分)
线段树合并+树上差分
个人认为这道题线段树合并的做法是最简单的
题意:给一棵树和 m条路径,树上每个点有一个值 \(W_i\) 。问对于每一个点,询问有多少条路径的第 \(W_i+1\)个点是这个点。\(n,m \leqslant 1e5\)
假设当前路径为\(s\)->\(t\),显然我们可以预处理出\(lca\)
我们考虑对于每一个节点建立一个权值线段树,以\(dep\)为下标,每个节点保存数值i的出现次数,其实我们只需要叶子结点上的信息
用树上差分把链的信息转化为点
现在对于\(s\)->\(lca\)的路径,只需要在\(s\)处的线段树让\(dep[s]\)+1,然后每个点统计答案时查询\(dep[x]+w[x]\)出现了几次即可即可
对于\(lca\)->\(t\)的路径,我们想办法把\(s\)关于\(lca\)翻上去,在在每个点统计dep=dep[x]-w[x]的点有几个
注意翻上去后d可能变为负的,所以要整体平移一下值域,都加上n即可
/*
天天爱跑步 2019.7.10
给一棵树和 m条路径,树上每个点有一个值 Wi 。问对于每一个点,有多少条路径的第 Wi+1个点是这个点。
n,m:3e5
*/
#include <cstdio>
#include <cctype>
#include <iostream>
#include <algorithm>
using namespace std;
#define rint register int
#define il inline
const int N=3e5+5;
il int read(int x=0,int f=1,char ch='0')
{
while(!isdigit(ch=getchar())) if(ch=='-') f=-1;
while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return f*x;
}
int n,m,w[N];
int head[N],ver[N<<1],nxt[N<<1],tot;
il void add(int x,int y){ ver[++tot]=y; nxt[tot]=head[x]; head[x]=tot; }
int dep[N],top[N],son[N],siz[N],fa[N];
void dfs1(int x,int ff)
{
fa[x]=ff; dep[x]=dep[ff]+1; siz[x]=1;
for(rint i=head[x];i;i=nxt[i])
{
int y=ver[i]; if(y==ff) continue;
dfs1(y,x); siz[x]+=siz[y];
if(siz[y]>siz[son[x]]) son[x]=y;
}
}
void dfs2(int x,int topf)
{
top[x]=topf;
if(!son[x]) return ;
dfs2(son[x],topf);
for(rint i=head[x];i;i=nxt[i])
{
int y=ver[i]; if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
il int LCA(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
return dep[x]>dep[y] ? y : x;
}
const int M=N*55;
int root[N],lc[M],rc[M],val[M],num;
void update(int &x,int l,int r,int v,int d)
{
if(!x) x=++num;
if(l==r) return (void)(val[x]+=d);
int mid=l+r>>1;
if(v<=mid) update(lc[x],l,mid,v,d);
else update(rc[x],mid+1,r,v,d);
}
int query(int x,int l,int r,int p)
{
if(!x) return 0;
if(l==r) return val[x];
int mid=l+r>>1;
if(p<=mid) return query(lc[x],l,mid,p);
else return query(rc[x],mid+1,r,p);
}
int merge(int a,int b,int l,int r)
{
if(!a || !b) return a+b;
if(l==r)
val[a]+=val[b];
else
{
int mid=l+r>>1;
lc[a]=merge(lc[a],lc[b],l,mid);
rc[a]=merge(rc[a],rc[b],mid+1,r);
}
return a;
}
/*
对于s->lca的路径,只需要在s处让dep[s]+1,然后每个点统计dep=dep[x]+w[x]的点有几个即可
对于lca->t的路径,我们想办法把s关于lca翻上去,在每个点统计dep=dep[x]-w[x]的点有几个
直接树上差分即可
注意翻上去后d可能变为负的,所以要整体平移一下
*/
int ans[N];
void dfs(int x)
{
for(rint i=head[x];i;i=nxt[i])
{
int y=ver[i]; if(y==fa[x]) continue;
dfs(y);
root[x]=merge(root[x],root[y],1,n<<1);
}
if(w[x] && n+dep[x]+w[x]<=2*n)//注意要判断没有越界
ans[x]+=query(root[x],1,n<<1,n+dep[x]+w[x]);
ans[x]+=query(root[x],1,n<<1,n+dep[x]-w[x]);
}
int main()
{
n=read(); m=read();
for(rint i=1,x,y;i<n;++i) x=read(),y=read(),add(x,y),add(y,x);
dfs1(1,0); dfs2(1,1);
for(rint i=1;i<=n;++i) w[i]=read();
for(rint i=1;i<=m;++i)
{
int x=read(),y=read(); int lca=LCA(x,y);
update(root[x],1,n<<1,n+dep[x],1);
update(root[y],1,n<<1,n+dep[lca]*2-dep[x],1);
update(root[lca],1,n<<1,n+dep[x],-1);
update(root[fa[lca]],1,n<<1,n+dep[lca]*2-dep[x],-1);
}
dfs(1);
for(rint i=1;i<=n;++i)
printf("%d ",ans[i]);
return 0;
}