树上差分 学习笔记
前置知识:差分
例题:P2367 语文成绩
序列维护区间加,最后询问序列最小值。
线段树
差分即可。
对于在原数列 \(a_u\) 到 \(a_v\) 都加一个 \(x\),考虑在差分数组 \(b\) 中,变化的只有 \(b_u\) 和 \(b_{v+1}\)。
因为在原数列 \(a_u\) 到 \(a_v\) 都加一个 \(x\),对于 \(u\) 之前和 \(v+1\) 之后数的差不会有任何变化。\(u+1\) 到 \(v\) 之前的数也不会有变化,实际上,只有 \(a_u\) 对于 \(a_{u-1}\) 的差相较于之前大了 \(x\),\(a_{v+1}\) 对于 \(a_v\) 的差相较于之前小了 \(x\)。所以维护区间加只需要 \(b_u+x\)、\(b_{v+1}-x\) 即可。
最后,对差分数组 \(b\) 跑一遍前缀和即可还原出原数组。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
long long n,m,a[5000010],b[5000010],s[5000010],ans=1000000000000;
int main(){
long long i,j,u,v;
scanf("%lld %lld",&n,&m);
for(i=1;i<=n;i++) scanf("%lld",&a[i]);
for(i=1;i<=n;i++){
b[i]=a[i]-a[i-1];
}
while(m--){
cin>>u>>v>>j;
b[u]+=j;
b[v+1]-=j;
}
for(i=1;i<=n;i++){
s[i]=b[i]+s[i-1];
}
for(i=1;i<=n;i++){
ans=min(ans,s[i]);
}
cout<<ans<<endl;
return 0;
}
树上差分
点差分
例题:P3128 [USACO15DEC]Max Flow P
在树上给出多条路径,问所有路径经过最多的点经过了多少次。
典型的点差分。
考虑对于差分数组 \(s\),一条从 \(u\) 到 \(v\) 的路径好像是只需要 \(s_u+1,s_v+1,s_{f_{\text{LCA}(u,v)}}-1\) 就行。
但这是不对的。
考虑在树上前缀和恢复原值时,\(\text{LCA(u,v)}\) 的值会因为 \(s_u\) 和 \(s_v\) 都加 \(1\) 而导致多加了一个 \(1\)。所以在点差分时 \(s_{\text{LCA}(u,v)}\) 也应该减 \(1\)。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const long long loglim=20;
long long n,m,h[50010],tot,f[50010][loglim+5],d[50010],s[50010],ans;
struct edge{
long long v,nxt;
}e[100010];
void add(long long u,long long v){
tot++;
e[tot].v=v; e[tot].nxt=h[u];
h[u]=tot;
}
void build(long long fr,long long u,long long dep){
long long i,j,v;
d[u]=dep;
for(i=1;i<=loglim;i++){
f[u][i]=f[f[u][i-1]][i-1];
}
for(i=h[u];i;i=e[i].nxt){
v=e[i].v;
if(v!=fr){
f[v][0]=u;
build(u,v,dep+1);
}
}
}
long long lca(long long u,long long v){
long long i,j;
if(d[u]<d[v]) swap(u,v);
for(i=loglim;i>=0;i--){
if(d[f[u][i]]>=d[v]) u=f[u][i];
}
if(u==v) return u;
for(i=loglim;i>=0;i--){
if(f[u][i]!=f[v][i]){
u=f[u][i];
v=f[v][i];
}
}
return f[u][0];
}
void solve(long long fr,long long u){//树上前缀和
long long i,j,v;
for(i=h[u];i;i=e[i].nxt){
v=e[i].v;
if(v!=fr){
solve(u,v);
s[u]+=s[v];
}
}
}
int main(){
long long i,j,u,v;
cin>>n>>m;
for(i=1;i<n;i++){
cin>>u>>v;
add(u,v);
add(v,u);
}
build(0,1,1);
while(m--){
cin>>u>>v;
s[u]++; //树
s[v]++; //上
j=lca(u,v); //点
s[j]--; //差
s[f[j][0]]--;//分
}
solve(0,1);
for(i=1;i<=n;i++){
ans=max(ans,s[i]);
}
cout<<ans<<endl;
return 0;
}
边差分
树边有非负权,给定多条路径。现可将某边权变为0,使得最长路径最小。输出该最小值。
gx:好难!
zkw:屑!
会树剖的zkw把这题秒了,gx只能去写他151行的树上差分。
直接说正解了:
最大路径最小暗示二分答案,显然这题答案满足单调性。二分答案的时候 \(check(mid)\) 判断能否通过把某条边变成 0 来使得答案小于等于 \(mid\)。
我们把长度(可以预处理)大于 \(mid\) 的路径叫做大路径,反过来就是小路径。我们需要找出一条边是所有大路径都经过的,看看能否通过把它变成0后使所有大路径的长度小于等于 \(mid\)(判断的时候只用判断最长的大路径(可以预处理)是否满足即可)。
判断一条边是否被所有大路径经过可以用边差分。首先把树边经过的次数存到儿子节点处。然后每有一条 \(u\) 到 \(v\) 的路径,只需 \(s_u+1,s_v+1,s_{\text{LCA(u,v)}}-2\) 即可。最后树上前缀和,还原原数组,即每条边经过的次数。
总时间复杂度:\(O(m \log^2 n)\)。
当然可以通过初始化 \(\text{LCA}\) 把时间复杂度降到 \(O(m \log n)\),但这样也能过。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int loglim=18;
int n,m,h[300010],tot,f[300010][loglim+5],g[300010][loglim+5],d[300010],s[300010],w[300010];
struct edge{
int v,w,nxt;
}e[600010];
struct path{
int u,v,op;
}p[300010];
inline int read(){
char ch=getchar();
int x=0,f=1;
while(ch<'0' || ch>'9') {
if(ch=='-')
f=-1;
ch=getchar();
}
while('0'<=ch && ch<='9') {
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
inline void add(int u,int v,int w){
tot++;
e[tot].v=v; e[tot].w=w; e[tot].nxt=h[u];
h[u]=tot;
}
inline void build(int fr,int u,int dep){
register int i,j,v;
d[u]=dep;
for(i=1;i<=loglim;i++){
f[u][i]=f[f[u][i-1]][i-1];
g[u][i]=g[u][i-1]+g[f[u][i-1]][i-1];
}
for(i=h[u];i;i=e[i].nxt){
v=e[i].v;
if(v!=fr){
f[v][0]=u;
g[v][0]=e[i].w;
w[v]=e[i].w;
build(u,v,dep+1);
}
}
}
inline int lca(int u,int v){
register int i,j;
if(d[u]<d[v]) swap(u,v);
for(i=loglim;i>=0;i--){
if(d[f[u][i]]>=d[v]){
u=f[u][i];
}
}
if(u==v) return u;
for(i=loglim;i>=0;i--){
if(f[u][i]!=f[v][i]){
u=f[u][i];
v=f[v][i];
}
}
return f[u][0];
}
inline int G(int u,int v){
register int i,j,now=0;
if(d[u]<d[v]) swap(u,v);
for(i=loglim;i>=0;i--){
if(d[f[u][i]]>=d[v]){
now+=g[u][i];
u=f[u][i];
}
}
if(u==v) return now;
for(i=loglim;i>=0;i--){
if(f[u][i]!=f[v][i]){
now+=g[u][i];
now+=g[v][i];
u=f[u][i];
v=f[v][i];
}
}
now+=g[u][0];
now+=g[v][0];
return now;
}
inline void solve(int fr,int u){
register int i,j,v;
for(i=h[u];i;i=e[i].nxt){
v=e[i].v;
if(v!=fr){
solve(u,v);
s[u]+=s[v];
}
}
}
inline bool cmpp(path u,path v){
return u.op>v.op;
}
inline bool check(int gx){
register int i,j,u,v,sum=0;
memset(s,0,sizeof(s));
for(i=1;i<=m;i++){
if(p[i].op<=gx) break;
else sum++;
u=p[i].u; v=p[i].v;
s[u]++; s[v]++;
s[lca(u,v)]-=2;
}
solve(0,1);
if(sum==0) return 1;
for(i=1;i<=n;i++){
if(s[i]==sum){
if(p[1].op-w[i]<=gx) return 1;
}
}
return 0;
}
int main(){
register int i,j,u,v,l,r,mid;
cin>>n>>m;
for(i=1;i<n;i++){
u=read(); v=read(); j=read();
add(u,v,j);
add(v,u,j);
}
build(0,1,1);
for(i=1;i<=m;i++){
p[i].u=read(); p[i].v=read();
u=p[i].u; v=p[i].v;
p[i].op=G(u,v);
}
sort(p+1,p+m+1,cmpp);
l=0;r=100000000000;
while(l<r){
mid=(l+r)>>1;
if(check(mid)){
r=mid;
}
else{
l=mid+1;
}
}
cout<<l<<endl;
return 0;
}