C69 线段树合并+树上差分 P1600 [NOIP2016 提高组] 天天爱跑步
视频链接:255 线段树合并+树上差分 P1600 [NOIP2016 提高组] 天天爱跑步_哔哩哔哩_bilibili
Luogu P1600 [NOIP2016 提高组] 天天爱跑步
#include <iostream> #include <cstring> #include <algorithm> #include <vector> using namespace std; void read(int &x){ //快读 x=0; char c=getchar(); while(!isdigit(c))c=getchar(); while(isdigit(c))x=x*10+c-'0',c=getchar(); } const int N=300005; #define mid (l+r)/2 int n,m,w[N],ans[N]; vector<int> g[N]; //邻接表 int fa[N][20],dep[N]; //树增 int root[N],tot; //线段树的根 int ls[N*55],rs[N*55],sum[N*55]; //sum[u]:深度u出现的次数 void dfs(int x,int f){ //树增 dep[x]=dep[f]+1; fa[x][0]=f; for(int i=1; i<=18; i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for(int y:g[x]) if(y!=f) dfs(y,x); } int lca(int x,int y){ //求lca if(dep[x]<dep[y]) swap(x,y); for(int i=18; ~i; i--) if(dep[fa[x][i]]>=dep[y])x=fa[x][i]; if(x==y) return y; for(int i=18; ~i; i--) if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i]; return fa[x][0]; } void change(int &u,int l,int r,int p,int k){ //点修 if(!u) u=++tot; if(l==r){sum[u]+=k; return;} if(p<=mid) change(ls[u],l,mid,p,k); else change(rs[u],mid+1,r,p,k); } int merge(int x,int y,int l,int r){ //合并 if(!x||!y) return x+y; if(l==r){sum[x]+=sum[y]; return x;} ls[x]=merge(ls[x],ls[y],l,mid); rs[x]=merge(rs[x],rs[y],mid+1,r); return x; } int query(int u,int l,int r,int p){ //点查 if(l==r) return sum[u]; if(p<=mid)return query(ls[u],l,mid,p); else return query(rs[u],mid+1,r,p); } void dfs2(int x){ //合并与统计 for(int y:g[x]){ if(y==fa[x][0]) continue; dfs2(y); root[x]=merge(root[x],root[y],1,n<<1); } if(w[x] && n+dep[x]+w[x]<=n<<1) ans[x]+=query(root[x],1,n<<1,n+dep[x]+w[x]); ans[x]+=query(root[x],1,n<<1,n+dep[x]-w[x]); } int main(){ read(n);read(m); int x,y; for(int i=1;i<n;++i){ read(x),read(y); g[x].push_back(y); g[y].push_back(x); } for(int i=1;i<=n;++i) read(w[i]); dfs(1,0); //树增 for(int i=1;i<=m;++i){ //差分 read(x),read(y); int l=lca(x,y); change(root[x],1,n<<1,n+dep[x],1); change(root[y],1,n<<1,n+2*dep[l]-dep[x],1); change(root[l],1,n<<1,n+dep[x],-1); change(root[fa[l][0]],1,n<<1,n+2*dep[l]-dep[x],-1); } dfs2(1); //合并与统计 for(int i=1;i<=n;++i) printf("%d ",ans[i]); }