讲义:浅谈树上问题的一些乱搞做法

鸣谢 @pig1121

\(\bf{课件}\)

树上启发式合并

树上启发式合并,即 \(\text{dsu on tree}\),是一种复杂度和实现都非常优秀的处理树上静态问题的算法,应用极广泛。

何为启发式?意思就是凭人类智慧想出来的神奇的优化/保证时间复杂度的算法。

比如并查集的启发式合并

void merge (int x, int y) {
	x = find(x), y = find(y);
	if (siz[x] < siz[y]) swap(x, y);
	fa[y] = x, siz[x] += siz[y];
}

CF600E Lomsat gelral

我们用这道例题来讲解 \(\text{dsu on tree}\) 的基本姿势。

\(cnt[i]\) 为 当前颜色 \(i\) 的出现次数,\(mx\) 为当前颜色出现次数的最大值,\(sum\) 为当前的答案。

显然直接做是 \(O(n^2)\) 的,考虑优化子树内信息的合并

借助启发式合并的思想,我们这么做:

  1. \(\text{dfs}\) 到节点 \(u\) 时,首先继续搜 \(u\) 的所有轻儿子的子树,计算答案,但不保留对 \(cnt[i]~/~mx~/~sum\) 的贡献

  2. \(u\) 的重儿子的子树,计算答案的同时保留其对 \(cnt[i]~/~mx~/~sum\) 的贡献

  3. 再搜一遍 \(u\) 的所有轻儿子的子树 ,不重复计算答案,保留其对 \(cnt[i]~/~mx~/~sum\) 的贡献

em… 时间复杂度呢?

易知每次遍历到一个节点时的操作 \(O(1)\) 的,思考每个节点被搜到了多少次。注意到 \(dfs\) 的函数体内,重儿子的子树会被遍历 \(1\) 次,而轻儿子的子树会被遍历 \(2\) 次。所以一个节点的遍历次数的大 \(O\) 界就是它到根节点得路径上的轻边个数,这是 \(O(\log n)\) 的,于是时间复杂度 \(O(n\log n)\)

#include<bits/stdc++.h>
#define pb push_back
#define is insert
#define fi first
#define se second
#define mkp make_pair
#define INF INT_MAX
#define mathmod(a,m) (((a)%(m)+(m))%(m))
#define mem(a,b) memset(a,b,sizeof a)
#define cpy(a,b) memcpy(a,b,sizeof b)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int N=1e5+5;
int n;
ll col[N];
int tim,st[N],ed[N],rk[N];
int siz[N],son[N];
ll sum,mx,cnt[N],ans[N];
vector<int> G[N];
void dfs1(int u,int fa){
	siz[u]=1,son[u]=-1;
	st[u]=++tim,rk[tim]=u;
	for(int v:G[u]){
		if(v==fa) continue;
		dfs1(v,u);
		siz[u]+=siz[v];
		if(!(~son[u])||siz[v]>siz[son[u]]) son[u]=v;
	}
	ed[u]=tim;
}
void update(int u,int fa,int x,int ch){
	cnt[col[u]]+=x;
	if(cnt[col[u]]>mx) mx=cnt[col[u]],sum=col[u];
	else if(cnt[col[u]]==mx) sum+=col[u];
	for(int v:G[u]) if(v!=fa&&v!=ch) update(v,u,x,ch);
}
void dfs2(int u,int fa,bool flag){
	for(int v:G[u]) if(v!=fa&&v!=son[u]) dfs2(v,u,0);
	if(~son[u]) dfs2(son[u],u,1);
	update(u,fa,1,son[u]);
	ans[u]=sum;
	if(!flag) update(u,fa,-1,-1),sum=mx=0;
}
int main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n;
	for(int i=1;i<=n;i++) cin>>col[i];
	for(int i=1,u,v;i<n;i++){
		cin>>u>>v;
		G[u].pb(v),G[v].pb(u);
	}
	dfs1(1,-1);
	dfs2(1,-1,1);
	for(int i=1;i<=n;i++) cout<<ans[i]<<" ";
	cout<<endl;
	return 0;
}

线段树合并 - 单点修改版\(^{[1]}\)

引入

考虑这么一个问题:

  • 一个 \(n\) 个节点的有根树,每个节点 \(i\) 有一个颜色 \(c_i\)
  • 对于每个 \(i\in[1,n]\),求出以 \(i\) 为根的子树中,编号在 \([l_i,r_i]\) 的颜色的出现次数和。
  • \(n\le 10^5\)\(c_i\le n\)

显然我们可以 \(\text{dsu~on~tree}\),用个线段树/树状数组解决问题,时间复杂度 \(O(n\log^2n)\),两只 \(\log\)

有没有单只 \(\log\) 的做法呢?

我会把子树条件转成 \(dfn\) 然后二维数点

\(\text{dsu on tree}\) 本质上还是优化后的暴力合并,没有具体的数据结构进行优化,那比方说,对于线段树这种树形数据结构,有没有什么单次/均摊 \(O(n\log n)\) 的做法呢?

分析

考虑对两棵动态开点线段树进行合并,这里我们默认合并是对应位置相加

线段树合并有两种写法:不可持久化的可持久化的,下面我们讲解不可持久化的线段树合并,可持久化时只需每次合并时新建节点即可,就略过了。

最最暴力的做法:直接 \(O(n)\) 暴力 \(\text{dfs}\),对应位置相加。

但是实际上当标记很复杂时,直接对应位置合并不是太好做。一般来讲,对应叶子节点的合并要比对应的大区间合并好做,所以我们可以搜到叶子节点后,暴力合并,然后一路 \(\text{pushup}\) 上来,实现:

int merge (int u, int v, int l, int r) {
	if (l == r) {
		sum[u] += sum[v];
		return u;
	}
	int mid = l + r >> 1;
	lc[u] = merge(lc[u], lc[v], l, mid);
	rc[u] = merge(rc[u], rc[v], mid + 1, r);
	pushup(u);
	return u;
}

注意到一个优化:由于是动态开点线段树,所以整棵树大概不是满的。那对于空儿子,我们在合并时显然不用管,也不需要继续递归下去合并了,就像这样:

int merge (int u, int v, int l, int r) {
	if (!u || !v) return u | v;
	if (l == r) {
		sum[u] += sum[v];
		return u;
	}
	int mid = l + r >> 1;
	lc[u] = merge(lc[u], lc[v], l, mid);
	rc[u] = merge(rc[u], rc[v], mid + 1, r);
	pushup(u);
	return u;
}

这个东西的时间复杂度呢?
结论:对于线段树 \(T_1,T_2,\dots,T_n\),将它们以任意顺序合并的复杂度为 \(O(\sum_{i=1}^{n}\lvert T_i\rvert-\lvert\sum_{i=1}^{n}T_i\rvert)\)

证明:懒得写,可以参考这篇博客\(^{[2]}\)

推论:对于值域为 \(n\) 的若干线段树,进行 \(q\) 次单点修改操作,共创建 \(O(q\log n)\) 个节点,时间复杂度为 \(O(q\log n-\lvert\sum_{i=1}^{n}T_i\rvert)=O(q\log n)\)

\(\red{p.s.}\)(接下来是我个人的一些经验之谈,不喜轻喷)

可见,对于线段树合并的复杂度分析,我们一般只需关注合并之前所创建的结点数量,所以就线段树合并算法本身来说,它在时间复杂度上一般不会成为该问题的决定性瓶颈常数大不大就是另外一回事了

还有,在思考线段树合并是否适用时,我们只需关注两件事:能否快速地合并两个叶子节点信息快速 \(\bold{pushup}\),而 \(95\%\) 的线段树问题都能做到这两点。

注意:线段树合并的优势不仅在其灵活性上,更重要的是它采用动态开点的方式,这允许它自由的分配空间,所以有时我们的线段树不一定要维护什么复杂的内容,它可以用来替换空间静态的数组,甚至有时无需 \(\text{pushup}\),采用 \(leafy\) 的线段树,只借助非叶子节点存储懒标记等。

一般来说,\(\text{dsu on tree}\) 能做的题线段树合并都能做,而且线段树本身就是数据结构,想起来当然自然一点(个人观点,轻喷),时间复杂度有时能少一只 \(\log\)但跑的一般都差不多快。

P4556 [Vani 有约会] 雨天的尾巴 /【模板】线段树合并

暴力显然,对于每一个节点维护 \(cnt\) 数组表示每个颜色的出现次数,这显然可以用树上差分做,但最后的合并是 \(O(nC)\) 的。

考虑使用动态开点线段树,替换 \(cnt\) 数组,\(cnt\) 数组的暴力合并换成线段树合并。

由前面的讨论知道,这部分复杂度是 \(O(q\log C)\) 的。

于是时间复杂度 \(O(n+q\log C)\)

// godmoo's code
#include<bits/stdc++.h>
#define pb push_back 
using namespace std;
const int N=1e5+5;
const int V=6e6+5;
const int LG=20;
const int MX=1e5;
int n,m,ans[N],dep[N],fa[LG][N];
vector<int> G[N];
int tot,rt[N],lc[V],rc[V],mx[V],val[V];
void pushup(int u){
	mx[u]=max(mx[lc[u]],mx[rc[u]]);
	val[u]=mx[u]==mx[lc[u]]?val[lc[u]]:val[rc[u]];
}
void add(int &u,int l,int r,int pos,int x){
	if(!u) u=++tot;
	if(l==r){mx[u]+=x,val[u]=pos;return;}
	int mid=l+r>>1;
	if(pos<=mid) add(lc[u],l,mid,pos,x);
	else add(rc[u],mid+1,r,pos,x);
	pushup(u);
}
int merge(int u,int v,int l,int r){
	if(!u||!v) return u|v;
	if(l==r){mx[u]+=mx[v],val[u]=l;return u;}
	int mid=l+r>>1;
	lc[u]=merge(lc[u],lc[v],l,mid);
	rc[u]=merge(rc[u],rc[v],mid+1,r);
	pushup(u);return u;
}
void calc(int u,int fat){
	dep[u]=dep[fa[0][u]=fat]+1;
	for(int i=1;i<=__lg(n);i++) fa[i][u]=fa[i-1][fa[i-1][u]];
	for(int v:G[u]) if(v!=fat) calc(v,u);
}
int lca(int u,int v){
	if(dep[u]<dep[v]) swap(u,v);
	for(int d=dep[u]-dep[v],i=0;d;d>>=1,i++) if(d&1) u=fa[i][u];
	if(u==v) return u;
	for(int i=__lg(n);~i;i--) if(fa[i][u]!=fa[i][v]) u=fa[i][u],v=fa[i][v];
	return fa[0][u];
}
void dfs(int u,int fat){
	for(int v:G[u]){
		if(v==fat) continue;
		dfs(v,u),rt[u]=merge(rt[u],rt[v],1,MX);
	}
	ans[u]=mx[rt[u]]?val[rt[u]]:0;
}
int main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n>>m;
	for(int i=1,u,v;i<n;i++) cin>>u>>v,G[u].pb(v),G[v].pb(u);
	calc(1,0);
	for(int i=1,u,v,col,_;i<=m;i++){
		cin>>u>>v>>col,_=lca(u,v);
		add(rt[u],1,MX,col,1);
		add(rt[v],1,MX,col,1);
		add(rt[_],1,MX,col,-1);
		if(fa[0][_]) add(rt[fa[0][_]],1,MX,col,-1);
	}
	dfs(1,0);
	for(int i=1;i<=n;i++) cout<<ans[i]<<"\n";
	return 0;
}

P1600 [NOIP2016 提高组] 天天爱跑步

这题属于是智商不够数据结构来凑的题。

首先将离线下来,询问挂到节点上。

对于每个节点 \(u\),考虑谁对 \(u\) 有贡献,简单分讨一下:

\(\alpha.~\) \(u\) 被一条路径经过,其中该路径满足 \(s\in\operatorname{subtree}(u)\)

那么显然有 \(dep[s]=dep[u]+w[u]\),开线段树记录 \(dep[s]\) 即可,线段树合并维护。

\(\beta.~\) \(u\) 被一条路径经过,其中该路径满足 \(t\in\operatorname{subtree}(u)\)

设走这条路径耗时为 \(tim\),那么有 \(tim=dep[t]-dep[u]+w[u]\),移项得 \(tim-dep[t]=w[u]-dep[u]\),同理扔到线段树里维护即可。

\(tim-dep[t]=w[u]-dep[u]\)

综上……等等,这样貌似会算重诶?!

\(s,t\in\operatorname{subtree}(u)\) 时,就会算重,也就是说:对于 \(s,t\),他们的公共祖先都会算重。

但注意到 \(s,t\) 实际上只会对 \(\operatorname{lca}(s,t)\) 产生贡献,于是先对 \(\operatorname{lca}(s,t)\) 特殊处理,即随便在一棵线段树中删掉,统计完答案后再在另一棵线段树中删掉即可,这点可以用类似树上差分的方法做。

时间复杂度 \(O(m\log n)\),可以通过。

// godmoo's code
#include<bits/stdc++.h>
#define pb push_back
#define is insert
#define fi first
#define se second
#define mkp make_pair
#define INF INT_MAX
#define mathmod(a,m) (((a)%(m)+(m))%(m))
#define mem(a,b) memset(a,b,sizeof a)
#define cpy(a,b) memcpy(a,b,sizeof b)
#define ck(a,b) (dep[a]<dep[b]?(a):(b))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int N=3e5+5;
const int V=2e7+5;
const int LG=20;
int n,m,tim,w[N],st[LG][N<<1],dfn[N],dep[N],ans[N],f[N];
vector<int> G[N];
struct Query{int s,t,w,l;}q[N];
void dfs(int u,int fa){
	st[0][dfn[u]=++tim]=u;
	dep[u]=dep[fa]+1;
	f[u]=fa;
	for(int v:G[u]){
		if(v==fa) continue;
		dfs(v,u);
		st[0][++tim]=u;
	}
}
int lca(int u,int v){
	if((u=dfn[u])>(v=dfn[v])) swap(u,v);
	int p=__lg(v-u+1);
	return ck(st[p][u],st[p][v-(1<<p)+1]);
}
int tot,rt[N][2],lc[V],rc[V],val[V];
void add(int &u,int l,int r,int pos,int x){
	if(!u) u=++tot;
	val[u]+=x;
	if(l==r) return;
	int mid=l+r>>1;
	if(pos<=mid) add(lc[u],l,mid,pos,x);
	else add(rc[u],mid+1,r,pos,x);
}
int query(int u,int l,int r,int pos){
	if(!u) return 0;
	if(l==r) return val[u];
	int mid=l+r>>1;
	if(pos<=mid) return query(lc[u],l,mid,pos);
	else return query(rc[u],mid+1,r,pos);
}
int merge(int u,int v,int l,int r){
	if(!u||!v) return u+v;
	if(l==r){val[u]+=val[v];return u;}
	int mid=l+r>>1;
	lc[u]=merge(lc[u],lc[v],l,mid);
	rc[u]=merge(rc[u],rc[v],mid+1,r);
	val[u]=val[lc[u]]+val[rc[u]];
	return u;
}
void work(int u,int fa){
	for(int v:G[u]){
		if(v==fa) continue;
		work(v,u);
		rt[u][0]=merge(rt[u][0],rt[v][0],1,n);
		rt[u][1]=merge(rt[u][1],rt[v][1],-n,n<<1);
	}
	ans[u]=(dep[u]+w[u]>n?0:query(rt[u][0],1,n,dep[u]+w[u]))+query(rt[u][1],-n,n<<1,w[u]-dep[u]);
}
int main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n>>m;
	for(int i=1,u,v;i<n;i++) cin>>u>>v,G[u].pb(v),G[v].pb(u);
	dfs(1,0);
	for(int i=1;i<=19;i++){
		for(int j=1;j+(1<<i)-1<=tim;j++){
			st[i][j]=ck(st[i-1][j],st[i-1][j+(1<<i-1)]);
		}
	}
	for(int i=1;i<=n;i++) cin>>w[i];
	for(int i=1;i<=m;i++){
		cin>>q[i].s>>q[i].t;
		q[i].w=lca(q[i].s,q[i].t);
		q[i].l=dep[q[i].s]+dep[q[i].t]-(dep[q[i].w]<<1);
		add(rt[q[i].s][0],1,n,dep[q[i].s],1);
		add(rt[q[i].w][0],1,n,dep[q[i].s],-1);
		add(rt[q[i].t][1],-n,n<<1,q[i].l-dep[q[i].t],1);
		add(rt[f[q[i].w]][1],-n,n<<1,q[i].l-dep[q[i].t],-1);
	}
	work(1,0);
	for(int i=1;i<=n;i++) cout<<ans[i]<<" ";
	cout<<flush;
	return 0;
}

P3899 [湖南集训] 更为厉害 / 谈笑风生

简单分讨一下:

\(\alpha.~\) \(a\)\(b\) 的祖先

答案为 \(\sum_{dep[b]~\in~[dep[a]+1,dep[a]+k]~\land~b~\in~\operatorname{subtree}(a)} siz[b]-1\)

线段树合并:每个节点维护一棵线段树即可,预处理时用线段树合并维护子树信息。

思考:别的做法?

\(\red{p.s.}\) 这里还可以离线 \(\text{dfs}\) 序上二维数点\(\text{dsu on tree}\),有人想听再讲吧。
二维数点
显然合法的 \(b\) 要满足 \(dep[b]\in [dep[a]+1,dep[a]+k]\)
这时设 \(rnk[i]\)\(i\)\(\text{dfs}\) 序中的位置,那么有:
\(rnk[b]\in [rnk[a]+1,rnk[a]+siz[a]-1]\)
离线下来二维数点即可。

\(\bf{dsu~on~tree}\)
同样能做,但需要加个线段树 / 树状数组

\(\beta.~\) \(b\)\(a\) 的祖先

答案为 \(\min(dep[a]-1,k)\times(siz[a]-1)\),直接算即可。

// godmoo's code
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int N=3e5+5;
const int V=6e7+5;
const int MX=3e5;
int n,q,dep[N],siz[N];
vector<int> G[N]; 
int tot,rt[N],lc[V],rc[V];
ll sum[V];
void pushup(int u){sum[u]=sum[lc[u]]+sum[rc[u]];}
void add(int &u,int l,int r,int pos,int x){
	if(!u) u=++tot;
	if(l==r){sum[u]+=x;return;}
	int mid=l+r>>1;
	if(pos<=mid) add(lc[u],l,mid,pos,x);
	else add(rc[u],mid+1,r,pos,x);
	pushup(u);
}
ll query(int u,int l,int r,int ql,int qr){
	if(!u) return 0;
	if(ql<=l&&r<=qr) return sum[u];
	int mid=l+r>>1;ll res=0;
	if(ql<=mid) res+=query(lc[u],l,mid,ql,qr);
	if(qr>mid) res+=query(rc[u],mid+1,r,ql,qr);
	return res;
}
int merge(int u,int v,int l,int r){
	if(!u||!v) return u|v;
	int now=++tot;
	if(l==r){sum[now]=sum[u]+sum[v];return now;}
	int mid=l+r>>1;
	lc[now]=merge(lc[u],lc[v],l,mid);
	rc[now]=merge(rc[u],rc[v],mid+1,r);
	pushup(now);return now;
}
void dfs(int u,int fa){
	dep[u]=dep[fa]+1,siz[u]=1;
	for(int v:G[u]) if(v!=fa) dfs(v,u),siz[u]+=siz[v];
	add(rt[u]=++tot,1,MX,dep[u],siz[u]-1);
	for(int v:G[u]) if(v!=fa) rt[u]=merge(rt[u],rt[v],1,MX);
}
int main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n>>q;
	for(int i=1,u,v;i<n;i++) cin>>u>>v,G[u].pb(v),G[v].pb(u);
	dfs(1,0);
	for(int i=1,a,k;i<=q;i++){
		cin>>a>>k;
		cout<<1ll*min(dep[a]-1,k)*(siz[a]-1)+query(rt[a],1,MX,dep[a]+1,dep[a]+k)<<"\n";
	}
	cout<<flush;
	return 0;
}

P7815 「小窝 R3」自傷無色

线段树合并。

\(\bf{I.}\) 切入点

  • 考虑对于 \((u,v)\),记 \(w=\operatorname{lca}(u,v)\)\(a=\operatorname{dis}(u,w)\)\(b=\operatorname{dis}(v,w)\),不失一般性,令 \(a\ge b\)

  • 那么对于合法的 \(x\) 显然有 \(a-b\lt x\lt a+b\),易知合法的树三角有 \(2b-1\) 个,它们的大小总和为 \((2b-1)(a+b+a)=4ab+2b^2-2a-b\)

  • 接着考虑怎么进行统计,发现直接枚举 \((u,v)\) 非常不现实,于是转而枚举 \(w\)

  • 套路的,只需保证 \(u\)\(v\)\(w\) 的不同子树内,就有 \(w=\operatorname{lca}(u,v)\) 了。

\(\bf{II.}\) 套路部分

  • 发现现在这变成了一个子树间的信息合并问题,考虑线段树合并,那接下来我们就得弄明白两个问题:何时进行合并怎么合并

  • 前几步非常显然,记录根节点到每个节点的距离 \(dep_u\),接着对于每个节点,将其子树内的节点的 \(dep\) 扔到线段树内维护(具体要维护什么呢?这个要由待会推式子的时候才能知道)。

  • 再重新下定义:记 \(a=dep_u\)\(b=dep_v\)\(c=dep_w\),不失一般性,令 \(a\ge b\)。那么原来的 \(a\) 就会变成现在的 \((a-c)\),对于 \(b\) 同理(方便后面明确)。

  • 接下来考虑子树内信息的合并。假设说已经将部分子树的信息合并到了当前的 \(w\) 的线段树 \(T_w\) 上,接下来将要合并子树 \(T_{son}\)怎么计算贡献?

  • 借助 \(\text{cdq}\) 分治的思想,当合并到非叶子节点时,计算左半的部分对另一棵子树右半部分的贡献,对于左右儿子内的贡献,递归处理即可;边界是到叶子节点时也要计算贡献。

  • 代码大概长这样:

int merge (int u, int v, int l, int r) {
	if (!u || !v) return u | v;
	if (l == r) {
		solve(u, v); 
		// merge u and v
		return u;
	}
	int mid = l + r >> 1;
	solve(lc[u], rc[v]), solve(lc[v], rc[u]);
	lc[u] = merge(lc[u], lc[v], l, mid);
	rc[u] = merge(rc[u], rc[v], mid + 1, r);
	pushup(u);
	return u;
}

\(\bf{III.}\) 计算贡献

  • 接下来我们考虑计算贡献,即 solve() 函数怎么写。换言之,我们要研究一棵值域较小的线段(子)树 \(T_x\) 和一棵值域较大的线段(子)树 \(T_y\) 怎么计算贡献。

  • 我们由前面重新下定义的部分,易知一对 \((u,v)\) 会产生的合法树三角个数为,

\[cnt=2(b-c)-1=2b-2c-1 \]

  • 它们的大小总和为,

\[\begin{align*} sum&=4(a-c)(b-c)+2(b-c)^2-2(a-c)-(b-c) \\\\ &=4ab-(4c+2)a-(8c+1)b+6c^2+3c+2b^2 \end{align*}\]

  • (下文的所有 \(k\in T\) 都指代 \(k\)\(T\)叶子节点。)

  • 那么会产生合法树三角个数为,

\[\begin{align*} cnt&=\sum_{b\in T_x}\sum_{a\in T_y}2b-2c-1 \\\\ &=2\sum_{a\in T_y}1\cdot\sum_{b \in T_x}b-(2c+1)\sum_{b\in T_x}1\cdot\sum_{a\in T_y}1 \end{align*}\]

  • 它们的总和为,

\[\begin{align*} sum&=\sum_{b\in T_x}\sum_{a\in T_y}4ab-(4c+2)a-(8c+1)b+6c^2+3c+2b^2 \\\\ &=4\sum_{b\in T_x}b\cdot\sum_{a\in T_y}a-(4c+2)\sum_{a\in T_y}a\cdot\sum_{b\in T_x}1-(8c+1)\sum_{b\in T_x}b\cdot\sum_{a\in T_y}1+(6c^2+3c)\sum_{b\in T_x}1\cdot\sum_{a\in T_y}1+2\sum_{b\in T_x}b^2\cdot\sum_{a\in T_y}1 \end{align*}\]

  • 于是我们在线段树上维护零次和一次和二次和就可以快速统计答案了。

  • 注意值域可以到 \(10^{14}\),离散化一下即可。

  • 时间复杂度 \(O(n\log n)\)

\(\red{p.s.}\) 其实这道题 \(\text{dsu on tree}\) 也能 \(O(n\log^2n)\) 做,但同样需要一个这样的数据结构,所以线段树合并可能更自然一点。

不写取模整数类会死得很惨。。。

// godmoo's code
#include<bits/stdc++.h>
#define pb push_back
#define is insert
#define fi first
#define se second
#define mkp make_pair
#define INF INT_MAX
#define mathmod(a,m) (((a)%(m)+(m))%(m))
#define mem(a,b) memset(a,b,sizeof a)
#define cpy(a,b) memcpy(a,b,sizeof b)
//#define int ll
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int N=1e5+5;
const int V=6e6+5;
const int MOD=1e9+7;
int n,m;
ll dis[N],tmp[N];
int tot,head[N];
struct Edge{int to,nxt,w;}edge[N<<1];
void add(int u,int v,int w){
	edge[++tot]={v,head[u],w};
	head[u]=tot;
}
void dfs(int u,int fa){
	for(int i=head[u];i;i=edge[i].nxt){
		int v=edge[i].to;
		if(v==fa) continue;
		tmp[v]=dis[v]=dis[u]+edge[i].w;
		dfs(v,u);
	}
}
int lsh(ll x){return lower_bound(tmp+1,tmp+m+1,x)-tmp;}
class ModInt{
	int x;
public:
	int val(){return x;}
	ModInt operator +(ModInt _) const{return ModInt(x+_.x);}
	ModInt operator -(ModInt _) const{return ModInt(x-_.x+MOD);}
	ModInt operator *(ModInt _) const{return ModInt(1ll*x*_.x);}
	void operator =(ll _){x=_;}
	ModInt(ll _=0ll):x(_%MOD){}
	ModInt(const ModInt& _):x(_.x){}
};
ModInt cnt,sum,c;
int cur,rt[N],lc[V],rc[V];
ModInt s0[V],s1[V],s2[V];
void pushup(int u){
	s0[u]=s0[lc[u]]+s0[rc[u]];
	s1[u]=s1[lc[u]]+s1[rc[u]];
	s2[u]=s2[lc[u]]+s2[rc[u]];
}
void add(int &u,int l,int r,int pos,int x){
	if(!u) u=++cur;
	if(l==r){
		s0[u]=s0[u]+1;
		s1[u]=s1[u]+x;
		s2[u]=s2[u]+1ll*x*x;
		return; 
	}
	int mid=l+r>>1;
	if(pos<=mid) add(lc[u],l,mid,pos,x);
	else add(rc[u],mid+1,r,pos,x);
	pushup(u);
}
void solve(int x,int y){
	if(!s0[x].val()||!s0[y].val()) return;
	ModInt x1y0=s1[x]*s0[y],x0y0=s0[x]*s0[y],x1y1=s1[x]*s1[y],x0y1=s0[x]*s1[y],x2y0=s2[x]*s0[y];
	cnt=cnt+x1y0*2-(c*2+1)*x0y0;
	sum=sum+x1y1*4-(c*4+2)*x0y1-(c*8+1)*x1y0+(c*c*6+c*3)*x0y0+x2y0*2;
}
int merge(int u,int v,int l,int r){
	if(!u||!v) return u|v;
	if(l==r){
		solve(u,v);
		s0[u]=s0[u]+s0[v];
		s1[u]=s1[u]+s1[v];
		s2[u]=s2[u]+s2[v];
		return u;
	}
	int mid=l+r>>1;
	solve(lc[u],rc[v]),solve(lc[v],rc[u]);
	lc[u]=merge(lc[u],lc[v],l,mid);
	rc[u]=merge(rc[u],rc[v],mid+1,r);
	pushup(u);
	return u;
}
void work(int u,int fa){
	for(int i=head[u];i;i=edge[i].nxt){
		int v=edge[i].to;
		if(v==fa) continue;
		work(v,u);
		c=ModInt(dis[u]);
		rt[u]=merge(rt[u],rt[v],1,m);
	}
	add(rt[u],1,m,lsh(dis[u]),dis[u]%MOD);
}
ll qpow(ll a,ll b=MOD-2,ll m=MOD){
	ll res=1;
	while(b){
		if(b&1) res=res*a%m;
		a=a*a%m,b>>=1;
	}
	return res;
}
signed main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n;
	for(int i=1,u,v,w;i<n;i++) cin>>u>>v>>w,add(u,v,w),add(v,u,w);
	dfs(1,0),sort(tmp+1,tmp+n+1),m=unique(tmp+1,tmp+n+1)-tmp-1;
	work(1,0);
	cout<<1ll*sum.val()*qpow(cnt.val())%MOD<<endl; 
	return 0;
}

线段树合并 - 区间修改版\(^{[1]}\)\(^{[3]}\)

我们来思考一个东西,怎么在线段树合并内兼容懒标记的操作呢?

对于一些特殊情况,我们可以采用标记永久化的技术,直接不下放懒标记,合并时合并懒标记即可。

当然这个方法是有局限性的,对于更复杂的标记,我们仍考虑怎么进行 \(\text{pushdown}\)

显然一个点的懒标记仅与它所处的线段树有关,与其它线段树无关,所以我们不妨在合并前下放懒标记,同时在区间修改时,下放懒标记前先创建节点,防止信息丢失。

那么显然一个节点的两个儿子一定是被同时建出来的,所以对于两儿子的节点,继续往下合并;而对于没有儿子的节点,我们就尝试着合并懒标记到另外一棵线段树上,把这棵线段树子树作为结果返回即可。

显然这样的时间复杂度均摊仍为 \(O(\log n)\),实现:

int merge (int u, int v, int l, int r) {
	if (!u || !v) return u | v;
	if (l == r) {
		sum[u] += sum[v];
		return u;
	}
	if (!lc[u] && !rc[u]) {
		sum[v] += sum[u];
		lzy[v] += lzy[u];
		return v;
	}
	if (!lc[v] && !rc[v]) {
		sum[u] += sum[v];
		lzy[u] += lzy[v];
		return u;
	}
	pushdown(u), pushdown(v);
	int mid = l + r >> 1;
	lc[u] = merge(lc[u], lc[v], l, mid);
	rc[u] = merge(rc[u], rc[v], mid + 1, r);
	pushup(u);
	return u;
}

现在我们就可以愉快地在均摊 \(O(\log n)\) 的复杂度下合并带区间修改的信息了 o((>ω< ))o 。

P2495 [SDOI2011] 消耗战

首先,当仅有一次询问时,容易想到树形 \(\text{dp}\)

\(f_u\) 为以 \(u\) 为根的子树内,使该子树内所有关键点无法到达的最小代价。

容易写出状态转移方程:

\[f_u=\begin{cases} \text{INF} & u~为关键点\\ \sum_{v\in son_u}\min(f_u,w(u,v)) & otherwise\\ \end{cases}\]

显然每次暴力是 \(O(nm)\)

发现这个东西形式较简单,且每次询问只需要考虑关键点的变化所带来的影响,于是离线下来,对每个节点 \(u\) 开一个线段树,第 \(i\) 个叶子存储第 \(i\) 次询问的 \(f_u\),这棵树是 \(\bold{leafy}\) 的。

对于每个 \(u\),枚举它的儿子 \(v\),首先 \(v\) 的线段树全局与 \(w(u,v)\)\(\min\),然后两棵线段树再合并即可。最后记得若 \(u\) 在某个询问中为关键点,就把对应位置的值单点修改为 \(\text{INF}\),这样就完成了对 \(u\) 的转移,没了。

虚树:这玩意貌似比我好写?!

// godmoo's code
#include<bits/stdc++.h>
#define pb push_back
#define is insert
#define fi first
#define se second
#define mkp make_pair
#define INF 2e18
#define mathmod(a,m) (((a)%(m)+(m))%(m))
#define mem(a,b) memset(a,b,sizeof a)
#define cpy(a,b) memcpy(a,b,sizeof b)
//#define int long long 
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int N=2.5e5+5;
const int M=5e5+5;
const int V=1.5e7+5;
int n,m;
int cur,head[N];
struct Edge{int to,nxt,w;}edge[N<<1];
void add(int u,int v,int w){
	edge[++cur]={v,head[u],w};
	head[u]=cur;
}
vector<int> q[N];
int tot,rt[N],lc[V],rc[V];
ll val[V]; // 只有叶子节点的值是必要的,其他节点仅用于暂存 lzy,所以可以开成 1 个
void pushtag(int u,ll x){val[u]=min(val[u],x);}
void pushdown(int u){
	if(val[u]==INF) return;
	pushtag(lc[u],val[u]),pushtag(rc[u],val[u]);
	val[u]=INF;
}
int create(){
	int u=++tot;
	val[u]=INF;
	return u;
}
void build(int n){for(int i=1;i<=n;i++) rt[i]=create();}
void change(int &u,int l,int r,int pos,ll x){ 
	if(!u) u=create();
	if(l==r){val[u]=INF;return;} 
	pushdown(u);
	int mid=l+r>>1;
	if(pos<=mid) change(lc[u],l,mid,pos,x);
	else change(rc[u],mid+1,r,pos,x);
}
int merge(int u,int v,int l,int r){
	if(!u||!v) return u|v;
	if(!lc[u]&&!rc[u]){val[u]+=val[v];return u;} //	这道题中 lzy 的本质是还没有传下去的全局取 min 值 
//	if(!lc[v]&&!rc[v]){val[v]+=val[u];return v;} // 这里不要改变原树结构,方便输出 
	pushdown(u),pushdown(v);
	int mid=l+r>>1;
	lc[u]=merge(lc[u],lc[v],l,mid);
	rc[u]=merge(rc[u],rc[v],mid+1,r);
	return u;
}
void dfs(int u,int fa){
	for(int i=head[u];i;i=edge[i].nxt){
		int v=edge[i].to,w=edge[i].w;
		if(v==fa) continue;
		dfs(v,u);
		pushtag(rt[v],w);
		rt[u]=merge(rt[u],rt[v],1,m);
	}
	for(int x:q[u]) change(rt[u],1,m,x,INF);
	// 叶子节点为 0,合并直接没有
}
void output(int u,int l,int r){
	if(l==r){cout<<val[u]<<"\n";return;}
	if(!lc[u]) lc[u]=create();
	if(!rc[u]) rc[u]=create();
	pushdown(u);
	int mid=l+r>>1;
	output(lc[u],l,mid);
	output(rc[u],mid+1,r);
}
signed main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n;
	for(int i=1,u,v,w;i<n;i++) cin>>u>>v>>w,add(u,v,w),add(v,u,w);
	cin>>m;
	for(int i=1,k,x;i<=m;i++){
		cin>>k;
		while(k--) cin>>x,q[x].pb(i);
	}
	dfs(1,0);
	output(rt[1],1,m);
	return 0;
}

P10834 [COTS 2023] 题 Zadatak

显然描述一个这样的图案只需其主对角线的前一半,换句话说,只需维护主对角线上这 \(\frac{n}{2}\) 个位置是黑还是白,我们就可以画出整个图案。

扔到线段树内维护,下标从主对角线的第 \(\frac{n}{2}\) 个位置到第 \(1\) 个位置递增,黑色为 \(1\),白色为 \(0\)。对于每个节点 \(u\),若它表示的范围为 \([l_u,r_u]\),那么它整块的面积为 \(sum_u=[r_u^2-(l_u-1)^2]\),设黑色部分面积为 \(val_u\),再维护反转标记 \(lzy_u\)\(pushup\)\(pushdown\) 是平凡的。

一开始所有节点都为黑色,直接全局打 \(lzy\) 标记。

对于叶子节点的合并,新的 \(sum\)\(sum_u\oplus sum_v\),代价为 \([sum_u\gt0~\land~sum_v\gt0]\times len_u\),直接返回就行;非叶子节点递归合并即可。

同时,对于没有儿子的区间,直接将懒标记合并到另一棵树上去,这部分是容易的。

记得根节点个数开两倍。

// godmoo's code
#include<bits/stdc++.h>
#define pb push_back
#define is insert
#define fi first
#define se second
#define mkp make_pair
#define INF INT_MAX
#define mathmod(a,m) (((a)%(m)+(m))%(m))
#define mem(a,b) memset(a,b,sizeof a)
#define cpy(a,b) memcpy(a,b,sizeof b)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<int,ll> pil;
const int N=1e5+5;
const int V=1e7+5;
const int A=5e5;
int n,a[N];
int tot,rt[N<<1],lc[V],rc[V],lzy[V];
ll sum[V],val[V];
void pushup(int u){
	val[u]=val[lc[u]]+val[rc[u]];
}
void pushtag(int u){
	val[u]=sum[u]-val[u];
	lzy[u]^=1;
}
void pushdown(int u){
	if(!lzy[u]) return;
	pushtag(lc[u]),pushtag(rc[u]);
	lzy[u]=0;
}
int create(int l,int r){
	int u=++tot;
	sum[u]=1ll*r*r-1ll*(l-1)*(l-1);
	return u;
}
void modify(int &u,int l,int r,int ql,int qr){
	if(!u) u=create(l,r);
	if(ql<=l&&r<=qr){pushtag(u);return;}
	int mid=l+r>>1;
	if(!lc[u]) lc[u]=create(l,mid);
	if(!rc[u]) rc[u]=create(mid+1,r);
	pushdown(u);
	if(ql<=mid) modify(lc[u],l,mid,ql,qr);
	if(qr>mid) modify(rc[u],mid+1,r,ql,qr);
	pushup(u);
}
pil merge(int u,int v,int l,int r){
	if(!val[u]) return {v,0};
	if(!val[v]) return {u,0};
	if(val[u]==sum[u]){ll res=val[v];pushtag(v);return {v,res};}
	if(val[v]==sum[v]){ll res=val[u];pushtag(u);return {u,res};}
	if(l==r){int res=(val[u]&&val[v])*sum[u];val[u]^=val[v];return {u,res};}
	ll mid=l+r>>1;
	pushdown(u),pushdown(v);
	ll res1,res2;
	tie(lc[u],res1)=merge(lc[u],lc[v],l,mid);
	tie(rc[u],res2)=merge(rc[u],rc[v],mid+1,r);
	pushup(u);
	return {u,res1+res2};
}
int main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n;
	for(int i=1;i<=n;i++) cin>>a[i],a[i]>>=1,modify(rt[i],1,A,1,a[i]);
	for(int i=1,u,v;i<n;i++){
		cin>>u>>v;
		auto tmp=merge(rt[u],rt[v],1,A);
		rt[i+n]=tmp.fi;
		cout<<(tmp.se<<2ll)<<"\n";
	}
	cout<<flush;
	return 0;
}

P7880 [Ynoi2006] rldcot

压轴题,但跟前文“线段树合并-区间修改版”关系不大,水 Ynoi

显然,现在可能产生贡献的点对有 \(O(n^2)\) 个,我们用 \(O(n\log n)-O(1)\)\(\text{ST}\) 表将点对的表示为 \(\{u,v,\operatorname{dep}(\operatorname{lca}(u,v))\}\) 的形式,然后离散化一下颜色就可以离线下来做扫描线了。

\(\red{p.s.}\)(会扫描线的可以跳过这一段)
具体的,所有询问挂到左端点上,从右往左扫。设扫到 \(x\),维护 \(pos[i]\) 表示颜色 \(i\) 左端点 \(\ge x\) 的所有点对的右端点的最小值(这是平凡的),那么区间 \([x,r]\) 的答案就是当前满足 \(pos[i]\le r\) 的颜色个数,用值域树状数组维护即可,时间复杂度 \(O(n^2\log n+q\log n)\)

发现瓶颈在点对数量上,考虑优化点对数量。换句话说:有没有一些点对是绝对冗余的?能不能通过找到一些非冗余点对的必要条件,以有效减少点对数量?

注意如果有四个点 \(a,b,c,d~\) 满足 \([a,b]\in[c,d]\),且 \(\operatorname{lca}(a,b)=\operatorname{lca}(c,d)=u\),那么点对 \((c,d)\) 必然是冗余的。

这启发我们,对于节点 \(u\)\(a\in\operatorname{subtree}(u)\),只有 \(a\)\(\operatorname{subtree}(u)\setminus\operatorname{subtree}(a)\) 中的节点编号意义上的前驱 \(b_1\) 和后继 \(b_2\) 才可能与 \(a\) 构成非冗余点对。

\(\red{p.s.}\) 显然,证明略。注意这里的前驱和后继是非严格的\((u,u)\) 当然是一个可能的非冗余点对,但这样并不是很好写代码……所以我们写代码时可以采用严格的前驱和后继,最后再把所有 \((u,u)\) 这样的点对扔进去,这样只会再增加 \(O(n)\) 个点对。

考虑 \(\bold{dsu~on~tree}\) 的过程,维护一个 std::set 表示上文的 \(\operatorname{subtree}(u)\setminus\operatorname{subtree}(a)\) 集合,每次从重儿子继承 std::set,依次让每一棵轻子树的每个节点在里面找前驱和后继,再把它们所有一起扔进 std::set

由于前驱和后继是逆运算,所以这样是正确的。

虽然轻儿子要遍历三遍,但由前面的时间复杂度分析,这个 \(\text{dsu~on~tree}\) 还是 \(n\log n\) 的;又由于点对仅在每次遍历时产生 \(O(1)\) 对,所以点对数量也是 \(O(n\log n)\) 的。

于是用前面扫描线的方式统计答案,总时间复杂度 \(O(n\log^2n+q\log n)\)

// godmoo's code
#include<bits/stdc++.h>
#define ep emplace
#define eb emplace_back
#define fi first
#define se second
#define mkp make_pair
#define INF INT_MAX
#define mathmod(a,m) (((a)%(m)+(m))%(m))
#define mem(a,b) memset(a,b,sizeof a)
#define cpy(a,b) memcpy(a,b,sizeof b)
using namespace std;
namespace FastIO{
	const int MX=1<<20;
	#ifdef ONLINE_JUDGE
	char buf[MX],*p1,*p2;
	#define gc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,MX,stdin),p1==p2)?0:*p1++)
	#else
	#define gc() getchar()
	#endif
	char pbuf[MX],*p3=pbuf;
	inline void pc(const char c){if(p3-pbuf==MX)fwrite(pbuf,1,MX,stdout),p3=pbuf;*p3++=c;}
	struct Flusher{~Flusher(){fwrite(pbuf,1,p3-pbuf,stdout);};}flusher;
	template<class T> inline void rd(T& p){
		T x=0,f=1;char c=gc();
		while(c<'0'||c>'9'){if(c=='-')f=0;c=gc();}
		while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=gc();
		p=f?x:-x;
	}
	template<class T,class ... Args> inline void rd(T& p,Args&... args){rd(p),rd(args...);}
	template<class T> inline void wr(T x){
		if(!x){pc('0');return;}
		static short sta[40];short tp=0;
		if(x<0) pc('-'),x=-x;
		do{sta[tp++]=x%10,x/=10;}while(x);
		while(tp) pc(sta[--tp]+'0');
	}
	template< > inline void wr(const char x){pc(x);}
	template< > inline void wr(const char* x){for(int i=0;x[i];i++)pc(x[i]);}
	template< > inline void wr(const string x){for(char c:x)pc(c);}
	template<class T,class ... Args> inline void wr(T x,Args... args){wr(x),wr(args...);}
}
using FastIO::rd; using FastIO::wr;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int N=1e5+5;
const int M=5e5+5;
int n,m,ans[M];
int cur,head[N];
struct Edge{int to,nxt,w;}edge[N<<1];
void add(int u,int v,int w){
	edge[++cur]={v,head[u],w};
	head[u]=cur;
}
int tot,siz[N],son[N];
ll dep[N],tmp[N];
set<int> s;
vector<pii> vec[N];
void dfs(int u,int fa){
	siz[u]=1;
	for(int i=head[u];i;i=edge[i].nxt){
		int v=edge[i].to;
		if(v==fa) continue;
		tmp[v]=dep[v]=dep[u]+edge[i].w;
		dfs(v,u);
		siz[u]+=siz[v];
		if(siz[v]>siz[son[u]]) son[u]=v;
	}
}
int pre(int x){
	auto it=s.lower_bound(x);
	return it!=s.begin()?*prev(it):0;
}
int nxt(int x){
	auto it=s.upper_bound(x);
	return it!=s.end()?*it:0;
}
void find(int u,int col){
	int p=pre(u),n=nxt(u);
	if(p) vec[p].eb(u,col);
	if(n) vec[u].eb(n,col);
}
void upd(int u,int fa,int col){
	find(u,col);
	for(int i=head[u];i;i=edge[i].nxt){
		int v=edge[i].to;
		if(v!=fa) upd(v,u,col);
	}
}
void add(int u,int fa){ 
	s.ep(u);
	for(int i=head[u];i;i=edge[i].nxt){
		int v=edge[i].to;
		if(v!=fa) add(v,u);
	}
}
void dsu(int u,int fa,int op){
	for(int i=head[u];i;i=edge[i].nxt){
		int v=edge[i].to;
		if(v!=fa&&v!=son[u]) dsu(v,u,0);
	}
	if(son[u]) dsu(son[u],u,1);
	find(u,dep[u]),s.ep(u);
	for(int i=head[u];i;i=edge[i].nxt){
		int v=edge[i].to;
		if(v!=fa&&v!=son[u]) upd(v,u,dep[u]),add(v,u);
	}
	if(!op) s.clear();
}
int pos[N]; // pos[i]:左点 >= i 时右点最小值 
vector<pii> q[N];
struct BIT{
	int tr[N];
	#define lb(x) ((x)&(-(x)))
	void add(int u,int x){for(;u<=n;u+=lb(u))tr[u]+=x;}
	int query(int u){int res=0;for(;u;u-=lb(u))res+=tr[u];return res;}
}T;
int main(){
	rd(n,m);
	for(int i=1,u,v,w;i<n;i++) rd(u,v,w),add(u,v,w),add(v,u,w);
	dfs(1,0);
	sort(tmp+1,tmp+n+1);
	tot=unique(tmp+1,tmp+n+1)-tmp-1;
	for(int i=1;i<=n;i++) dep[i]=lower_bound(tmp+1,tmp+n+1,dep[i])-tmp;
	dsu(1,0,1);
	for(int i=1;i<=n;i++) vec[i].eb(i,dep[i]);
	for(int i=1,l,r;i<=m;i++) rd(l,r),q[l].eb(r,i);
	for(int i=1;i<=n;i++) pos[i]=n+1;
	for(int l,r=n,c;r;r--){
		for(auto p:vec[r]){
			tie(l,c)=p;
			T.add(pos[c],-1);
			pos[c]=min(pos[c],l);
			T.add(pos[c],1);
		}
		for(auto p:q[r]) ans[p.se]=T.query(p.fi);
	}
	for(int i=1;i<=m;i++) wr(ans[i],'\n');
	return 0;
}

参考资料

完结撒花 ~

posted @   godmoo  阅读(5)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示