SDSC整理(Day2 树上倍增和差分)
树上倍增与差分
在 \(SDSC \ Day2\) 中树上倍增和差分这两个知识点的理论部分老师讲课的时间不超过 \(5\) 分钟。。。。。。
但是近年来 \(noip\) 提高组考树上差分的题倒是考得挺多的。
可以看出这几个知识点也是非常重要的,网上在这方面的知识比较少(oi wiki:你直接点我名算了),所以此文章是我个人的一些见解,不保证正确性。。。
树上倍增
\(Pre-knowledge\)
简单倍增,不会请退役。
\(description\)
树上倍增,顾名思义就是在树上进行倍增,其实现也非常简单。
我们维护一个倍增数组 \(ST[i][j]\) 表示 \(i\) 结点向上跳 \(2^j\) 步所到达的位置。
更新倍增数组时只有一行操作:
其意思也非常明显,就是从这个节点往上跳 \(2^i\) 步就相当于先向上跳 \(2^{i-1}\) 步,然后再跳 \(2^{i-1}\) 步。
那么很明显, \(ST\) 数组的初值就是 \(ST[i][0]=fa\) ,表示当前位置向上跳一步就是他的父亲。
其实我感觉和 \(ST\) 表的原理是差不多的,都是利用了倍增的思想。
树上倍增的一个常见应用就是求 \(LCA\) ,相信大家都会,这里直接给代码
/*
倍增求LCA
date:2022.7.3
worked by respect_lowsmile
*/
#include<iostream>
using namespace std;
const int N=5e5+5;
int n,m,s,cnt,a,b;
struct node
{
int to,next;
};
node edge[N<<1];
int head[N],deep[N],ST[N][25],lg[N];
inline int read()
{
int s=0,w=1;char ch=getchar();
while(ch<'0'||ch>'9')
{ if(ch=='-') w=-1; ch=getchar();}
while(ch>='0'&&ch<='9')
{ s=s*10+ch-'0'; ch=getchar();}
return s*w;
}
void add(int u,int v)
{
cnt++;
edge[cnt].to=v;
edge[cnt].next=head[u];
head[u]=cnt;
}
void dfs(int now,int fa)
{
ST[now][0]=fa,deep[now]=deep[fa]+1;
//cout<<"ceshi:"<<now<<" "<<deep[now]<<endl;
for(int i=1;i<=lg[deep[now]];++i)
ST[now][i]=ST[ST[now][i-1]][i-1];
for(int i=head[now];i;i=edge[i].next)
if(edge[i].to!=fa) dfs(edge[i].to,now);
}
int LCA(int x,int y)
{
if(deep[x]<deep[y]) swap(x,y);
while(deep[x]>deep[y])
{
x=ST[x][lg[deep[x]-deep[y]]-1];
//cout<<x<<" ";
}
if(x==y) return x;
for(int j=lg[deep[x]]-1;j>=0;--j)
if(ST[x][j]!=ST[y][j])
x=ST[x][j],y=ST[y][j];
return ST[x][0];
}
int main()
{
n=read(),m=read(),s=read();
for(int i=1;i<=n-1;++i)
{
int x,y;
x=read(),y=read();
add(x,y),add(y,x);
}
for(int i=1;i<=n;++i)
lg[i]=lg[i-1]+(1<< lg[i-1]==i);
dfs(s,0);
for(int i=1;i<=m;++i)
{
a=read(),b=read();
printf("%d\n",LCA(a,b));
}
return 0;
}
树上差分
还记得树链剖分那篇文章在引入树链剖分时提到的树上差分可做的问题吗?我们可以用树上差分轻松地解决
将树中从 \(x\) 到 \(y\) 的最短路径的所有点的权值都加上 \(z\)
这一操作。
\(Pre-knowledge\)
差分,这里帮忙推一下朋友的博客 link
其实不会的前置知识全部可以问gym学长
\(description\)
点差分
我们用数组 \(val\) 来存储每个点被访问的次数
那么对于从 \(x\) 到 \(y\) 的所有点,我们只需要 \(val[x]++\) , \(val[y]++\) , \(val [lca(x,y)]--\) , \(val[fa(lca(x,y))]--\) 即可。
然后 \(dfs\) 把每一个节点的值都加到他的父亲上,修改操作就完成了。
我们来看一个例子:
其实我们也可以这样理解,当我们这个节点 \(+1\) 之后,因为我们 \(dfs\) 时要把儿子累加的父亲,所以我们就相当于从这个点一直到根节点全部都 \(+1\)
但是,我们只用加到 \(LCA\) 就可以了,所以 \(val[fa[lca(x,y)]] --\)
但是我们起点到 \(lca\) 和终点到 \(lca\) 这两条路径都会这样处理,这就导致我们的 \(lca\) 这个地方加了两次,所以要再 \(val[lca]--\) 。
扔下一道板子题
code
/*
树上差分(点差分)
date:2022.7.26
worked by respect_lowsmile
*/
#include<iostream>
using namespace std;
const int N=3e5+5;
struct node
{
int to,next;
};
node edge[N<<1];
int num,n;
int head[N],ST[N][25],val[N],a[N],deep[N],lg[N];
void add(int u,int v)
{
num++;
edge[num].to=v;
edge[num].next=head[u];
head[u]=num;
}
void dfs(int now,int f)
{
deep[now]=deep[f]+1;
ST[now][0]=f;
for(int i=1;i<=lg[deep[now]];i++)
ST[now][i]=ST[ST[now][i-1]][i-1];
for(int i=head[now];i;i=edge[i].next)
{
int v=edge[i].to;
if(v==f) continue;
dfs(v,now);
}
}
int LCA(int x,int y)
{
if(deep[x]<deep[y]) swap(x,y);
while(deep[x]>deep[y])
x=ST[x][lg[deep[x]-deep[y]]-1];
if(x==y) return x;
for(int i=lg[deep[x]];i>=0;--i)
{
if(ST[x][i]!=ST[y][i])
x=ST[x][i],y=ST[y][i];
}
return ST[x][0];
}
void get_sum(int now,int f)
{
for(int i=head[now];i;i=edge[i].next)
{
int v=edge[i].to;
if(v==f) continue;
get_sum(v,now);
val[now]+=val[v];
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;++i)
lg[i]=lg[i-1]+(1<<lg[i-1]==i);
for(int i=1;i<=n;++i)
scanf("%d",&a[i]);
for(int i=1;i<n;++i)
{
int u,v;
scanf("%d %d",&u,&v);
add(u,v),add(v,u);
}
dfs(1,0);
for(int i=1;i<n;++i)
{
int lca=LCA(a[i],a[i+1]);
val[a[i]]++,val[a[i+1]]++,val[lca]--,val[ST[lca][0]]--;
}
get_sum(1,0);
for(int i=2;i<=n;++i)
val[a[i]]--;
for(int i=1;i<=n;++i)
printf("%d\n",val[i]);
return 0;
}
例题
边差分
如果你看过我的基于边权的树链剖分那篇文章的话,相信你一定知道该怎么做((大雾
我们可以把所有的边权都压入深度较深的点中,这样就把边权转化成了点权。
因为边权被压入了深度较深的点,所以我们的 \(val[lca]\) 存储的是 \(lca\) 的上一条边,是不在我们的修改范围之内的
所以我们在差分的时候直接 \(val[x]++\) , \(val[y]++\) , \(val[lca(x,y)]-=2\) 即可。
扔下一道例题
P6869 [COCI2019-2020#5] Putovanje
code
/*
树上差分(边差分)
date:2022.7.26
worked by respect_lowsmile
*/
#include<iostream>
#define int long long
using namespace std;
const int N=2e5+5;
struct node
{
int to,next,w1,w2;
};
node edge[N<<1];
int num,n,ans;
int head[N],deep[N],ST[N][25],lg[N],val[N],c1[N],c2[N];
int Min(int a,int b)
{
return a<b?a:b;
}
void add(int u,int v,int w1,int w2)
{
num++;
edge[num].to=v;
edge[num].next=head[u];
edge[num].w1=w1;
edge[num].w2=w2;
head[u]=num;
}
void dfs(int now,int f)
{
deep[now]=deep[f]+1;
ST[now][0]=f;
for(int i=1;i<=lg[deep[now]];i++)
ST[now][i]=ST[ST[now][i-1]][i-1];
for(int i=head[now];i;i=edge[i].next)
{
int v=edge[i].to,w1=edge[i].w1,w2=edge[i].w2;
if(v==f) continue;
c1[v]=w1,c2[v]=w2;
dfs(v,now);
}
}
int LCA(int x,int y)
{
if(deep[x]<deep[y]) swap(x,y);
while(deep[x]>deep[y])
x=ST[x][lg[deep[x]-deep[y]]-1];
if(x==y) return x;
for(int i=lg[deep[x]];i>=0;i--)
{
if(ST[x][i]!=ST[y][i])
x=ST[x][i],y=ST[y][i];
}
return ST[x][0];
}
void get_sum(int now,int f)
{
for(int i=head[now];i;i=edge[i].next)
{
int v=edge[i].to;
if(v==f) continue;
get_sum(v,now);
val[now]+=val[v];
}
}
signed main()
{
scanf("%lld",&n);
for(int i=1;i<=n;++i)
lg[i]=lg[i-1]+(1<<lg[i-1]==i);
for(int i=1;i<n;++i)
{
int u,v,w1,w2;
scanf("%lld %lld %lld %lld",&u,&v,&w1,&w2);
add(u,v,w1,w2),add(v,u,w1,w2);
}
dfs(1,0);
for(int i=1;i<n;++i)
{
int lca=LCA(i,i+1);
val[i]++,val[i+1]++,val[lca]-=2;
}
get_sum(1,0);
for(int i=1;i<=n;++i)
{
if(val[i]==0) continue;
if(val[i]==1) ans+=c1[i];
if(val[i]>1) ans+=Min(c2[i],c1[i]*val[i]);
}
printf("%lld",ans);
return 0;
}
例题的话没找快逃,做一下上面的 \(noip\) 的题目就可以了。
continue......