树上差分
在某些情况下可以和树剖替换。(当然某些情况替代不了,我愿称之为小树剖)
前置知识
- 倍增求LCA
- 了解差分思想
(既然要替代树剖怎么可能会用树剖求LCA呢/xyx)
引入
首先我们都知道差分是可以用来通过单点修改维护区间信息的(如单修区查的树状数组),因为我们在前面加上,再在后面减掉,中间这一段是得到了一个区间加的收益的。
那么我们考虑把这个区间转到树上。
分析
我们考虑对于一条路径区间加,然后思考如何差分。
我们记路径两个端点为 \(x\), \(y\)。不难发现树上路劲总是由路径 \((x,LCA(x,y) )\) 和路径 \((y,LCA(x,y) )\) 组成,显然这个 \(LCA(x,y)\) 是突破口。
可以效仿普通的差分在 \(x\),\(y\) 处加上需要增加的值,发现在其 \(LCA\) 以上的部分是不应该享受到增益的,所以我们在其 \(LCA\) 减掉这个值即可。
但是我们很快发现不对劲,\(LCA\) 处的值计算了两遍,我们只需要在 \(LCA\) 处也减去这个值即可。
具体题目分析
模板题,不多说了,直接按照分析来即可。
#include<bits/stdc++.h>
using namespace std;
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)
#define LL long long
#define N 50000
#define INF (int)(1e9)
inline int read(){
int s=0,f=1;char ch=getchar();
while(ch<'0'||'9'<ch) {if(ch=='-') f=-1;ch=getchar();}
while('0'<=ch&&ch<='9') {s=s*10+(ch^48);ch=getchar();}
return s*f;
}
inline void write(int x){
if(x<0) x=-x,putchar('-');
static char Wstack[20],top=0;
do{Wstack[top]=x%10;x/=10;++top;} while(x);
while(top) {putchar(Wstack[--top]+'0');}
}
int n,m;
vector<int>head,to,nxt;
void join(int u,int v){
nxt.push_back(head[u]);
head[u]=to.size();
to.push_back(v);
}
int d[N+3];
int fa[N+3][20+3],dep[N+3];
void dfs1(int u,int f){
for(int i=0;i<20;++i){
fa[u][i+1]=fa[fa[u][i]][i];
}
for(int i=head[u];~i;i=nxt[i]){
if(to[i]==f) continue;
fa[to[i]][0]=u;
dep[to[i]]=dep[u]+1;
dfs1(to[i],u);
}
}
inline void Swap(int &x,int &y){
int tmp=x;x=y;y=tmp;
}
int LCA(int x,int y){
if(dep[x]<dep[y]) Swap(x,y);
//printf("%d %d\n",x,y);
for(int i=20;i>=0;--i){
if(dep[fa[x][i]]>=dep[y]){
x=fa[x][i];
}
if(x==y) return x;
}
for(int i=20;i>=0;--i){
if(fa[x][i]!=fa[y][i]){
x=fa[x][i];
y=fa[y][i];
}
}
return fa[x][0];
}
int ans=-INF;
inline int Max(int x,int y){
return x>y?x:y;
}
int dfs(int u){
int res=d[u];
for(int i=head[u];~i;i=nxt[i]){
if(to[i]==fa[u][0]) continue;
res+=dfs(to[i]);
}
ans=Max(ans,res);
return res;
}
int main(){
n=read();m=read();
head.resize(n+3,-1);
for(int i=1;i<n;++i){
int u=read(),v=read();
join(u,v);join(v,u);
}
dfs1(1,1);
for(int i=1;i<=m;++i){
int x=read(),y=read();
++d[x];++d[y];
int f=LCA(x,y);
//printf("%d-%d->%d\n",x,y,LCA(x,y));
--d[f];--d[fa[f][0]];
}
dfs(1);
write(ans);
return 0;
}
稍微一点点不是模板的模板(我这说的什么)
因为最后一次走到的点不需要取糖,并且为了防止每一次的始终点重复取,考虑对于终点进行处理。如果终点在起点的父亲链上,那么直接取终点到起点的路径上的儿子为真正的终点即可;如果不在则直接取终点的父亲为真正的终点,其他模板。
#include<bits/stdc++.h>
using namespace std;
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)
#define LL long long
#define N 300000+3
inline int read(){
int s=0,f=1;
char ch=getchar();
while(ch<'0'||'9'<ch) {if(ch=='-') f=-1;ch=getchar();}
while('0'<=ch&&ch<='9') {s=s*10+(ch^48);ch=getchar();}
return s*f;
}
inline void write(int x){
if(x<0) x=-x,putchar('-');
static char Wstack[20],top=0;
do{Wstack[top]=x%10;x/=10;++top;}while(x);
while(top){putchar(Wstack[--top]+'0');}
}
int n,a[N],d[N];
vector<int>head,to,nxt;
void join(int u,int v){
nxt.push_back(head[u]);
head[u]=to.size();
to.push_back(v);
}
int fa[N][20+3],dep[N];
void dfs1(int u,int f){
dep[u]=dep[f]+1;
for(int i=0;i<20;++i){
fa[u][i+1]=fa[fa[u][i]][i];
}
for(int i=head[u];~i;i=nxt[i]){
if(to[i]==f) continue;
fa[to[i]][0]=u;
dfs1(to[i],u);
}
}
inline void Swap(int &x,int &y){
int tmp=x;x=y;y=tmp;
}
int LCA(int x,int y){
if(dep[x]<dep[y]) Swap(x,y);
for(int i=20;i>=0;--i){
if(dep[fa[x][i]]>=dep[y]){
x=fa[x][i];
}
if(x==y) return x;
}
for(int i=20;i>=0;--i){
if(fa[x][i]!=fa[y][i]){
x=fa[x][i];
y=fa[y][i];
}
}
return fa[x][0];
}
int ans[N];
int dfs(int u){
int res=d[u];
for(int i=head[u];~i;i=nxt[i]){
if(to[i]==fa[u][0]) continue;
res+=dfs(to[i]);
}
ans[u]=res;
return res;
}
int main(){
n=read();
head.resize(n+3,-1);
for(int i=1;i<=n;++i){
a[i]=read();
}
for(int i=1;i<n;++i){
int u=read(),v=read();
join(u,v);join(v,u);
}
dfs1(1,1);
for(int i=1;i<n;++i){
int x=a[i],y=a[i+1];
int t=x;
for(int j=20;j>=0;--j){
if(dep[fa[t][j]]>=dep[y]){
t=fa[t][j];
}
}
if(t==y){
int t=x;
for(int j=20;j>=0;--j){
if(dep[fa[t][j]]>dep[y]){
t=fa[t][j];
}
}
y=t;
}else{
y=fa[y][0];
}
int f=LCA(x,y);
++d[x];++d[y];
--d[f];--d[fa[f][0]];
//printf("%d %d %d %d\n",x,y,f,fa[f][0]);
}
dfs(1);
/*
for(int i=1;i<=n;++i){
printf("%d:%d\n",i,d[i]);
}
printf("%d\n",LCA(1,2));
*/
for(int i=1;i<=n;++i){
write(ans[i]);
putchar('\n');
}
return 0;
}