Luogu P1600 天天爱跑步 树上差分

Luogu P1600 天天爱跑步

树上差分


题目链接
树上问题
没看出怎么差分
被观察到的条件有两个
lca前一半(包括lca) \(d[S_i]-d[x]=w[x]\)
\(d[i]\)表示节点深度
lca后一半 \(d[S_i]+d[x]-2*d[lca(S_i,T_i)]=w[x]\)
但是具体怎么实现这个公式??
实现 \(d[S_i]=w[x]+d[x]\)
可以转化为线段树合并模型
\(S_i\)\(lca\)的路径上增加\(d[S_i]\)的价值
最后求出各点的\(w[x]+d[x]\)的价值个数
但是好像有点开不下。
可以在每一个节点上建一个vector记录投放
然后一个全局数组c
在树上dfs的时候记录\(c[w[x]+d[x]]\)递归回来的时候和原来的做差就是答案。


代码如下:

#include<bits/stdc++.h>
#define mk make_pair
using namespace std;
const int maxn=300000;
int n,m,head[maxn],tot,w[maxn],lc[maxn],ans[maxn];
int id[maxn],d[maxn],fa[maxn],f[maxn],c1[maxn*2],c2[maxn*2];
struct node{
	int nxt,to;
	#define nxt(x) e[x].nxt
	#define to(x) e[x].to
}e[maxn<<1];
inline void add(int from,int to){
	to(++tot)=to;nxt(tot)=head[from];head[from]=tot;
}
inline int find(int x){return fa[x]==x ? fa[x] : fa[x]=find(fa[x]);}
vector<pair<int,int> > pt[maxn];
vector<int> a1[maxn],a2[maxn],b1[maxn],b2[maxn];
pair<int,int> di[maxn];
void tarjan(int x){
	id[x]=1;
	for(int i=head[x];i;i=nxt(i)){
		if(id[to(i)]) continue;
		d[to(i)]=d[x]+1;
		tarjan(to(i));
		fa[to(i)]=x;f[to(i)]=x;
	}
	for(unsigned int i=0;i<pt[x].size();i++){
		int to=pt[x][i].first,vl=pt[x][i].second;
		if(id[to]==2) lc[vl]=find(to);
	}
	id[x]=2;
}
void dfs(int x){
	int val1=c1[d[x]+w[x]],val2=c2[w[x]-d[x]+n];
	id[x]=1;
	for(int i=head[x];i;i=nxt(i)){
		if(id[to(i)]) continue;
		dfs(to(i));
	}
	for(int i=0;i<a1[x].size();i++)
		c1[a1[x][i]]++;
	for(int i=0;i<b1[x].size();i++)
		c1[b1[x][i]]--;
	for(int i=0;i<a2[x].size();i++)
		c2[n+a2[x][i]]++;
	for(int i=0;i<b2[x].size();i++)
		c2[n+b2[x][i]]--;
	ans[x]=c1[d[x]+w[x]]+c2[w[x]-d[x]+n]-val1-val2;
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<n;i++){
		int x,y;scanf("%d%d",&x,&y);
		add(x,y);add(y,x);
	}
	for(int i=1;i<=n;i++) fa[i]=i,scanf("%d",&w[i]);
	for(int i=1;i<=m;i++){
		int x,y;scanf("%d%d",&x,&y);
		if(x==y) lc[i]=x;
		pt[x].push_back(mk(y,i));pt[y].push_back(mk(x,i));
		di[i].first=x;di[i].second=y;
	}
	d[1]=1;tarjan(1);
	for(int i=1;i<=m;i++){
		int x=di[i].first,y=di[i].second;
		a1[x].push_back(d[x]);a2[y].push_back(d[x]-2*d[lc[i]]);
		b1[f[lc[i]]].push_back(d[x]);b2[lc[i]].push_back(d[x]-2*d[lc[i]]);
	}
	memset(id,0,sizeof(id));
	dfs(1);
	for(int i=1;i<=n;i++) printf("%d ",ans[i]);
	return 0;
}
posted @ 2019-10-03 14:47  ChrisKKK  阅读(128)  评论(0编辑  收藏  举报