题解 BZOJ4543【[POI2014] HOT-Hotels】

长链剖分优化 DP 板子题了,但是虽然是板子这个转移方程也很难想。

problem

树。求 \(\sum_{1\leq i<j<k\leq n}[dist(i,j)=dist(i,k)=dist(j,k)].\)\(n\leq 10^5\)

solution \(O(n^2)\)

枚举每个点作为中心,以他为根搜索,然后相当于是它的每一个子树中选一个,一共选三个,是一个简单背包。

solution

与重链剖分相似,长链剖分是将树剖成很多条长链。

我们定义长儿子,为一个点的儿子中子树深度最大的一个儿子。或者这样定义:

  • 给每个点一个 \(height_u\)。叶子节点的 \(height\)\(1\)
  • 则其它所有点的 \(height_u\) 定义为 \(\max_{v\in son}\{height_v\}+1\),并记录取到 \(height_u\) 的儿子 \(v\) 为长儿子(保留一个)
  • 长链就是一条到叶节点的链,满足链上任意一个父亲的儿子是它的长儿子。

有几个性质值得留意:

  • 所有长链长度总和为 \(O(n)\),因为每一个点都只在一个长链中。
  • 任意一个节点 \(x\)\(k\) 级祖先 \(y\) 所在长链的长度一定大于等于 k,否则 \(y\) 所在长链应该接到 \(x\) 上去。
  • 从一个节点开始向上跳长链,最多跳 \(O(\sqrt{n})\) 次。因为每次跳到的长链长度一定是单调递增的。

长链剖分可以 \(O(n\log n)-O(1)\) 求 LCA、LA。

以下缩写 \(height_u\)\(len_u\),如果 \(u\) 是长链顶,那么 \(len_u\) 是这条长链的点数。另外还有不等式 \(len_u\geq len_v+1\),证明请参见 \(len_u\) 的定义。


考虑这个题的 DP 怎么写,这里直接给结论了。

  • \(f(u,d)\) 表示点 \(u\) 的子树中离它距离为 \(d\) 的点个数。根据定义,\(0\leq d<len_u\),边界 \(f(u,0)=1\) 即它自己。
  • \(f(u,d)=\sum_v f(v,d-1)\)
  • \(g(u,d)\) 表示,点 \(u\) 的子树中,有多少点对 \((i,j)\),使得 \(dist(i,k)=dist(j,k)=dist(k,u)+d\),其中 \(k=lca(i,j)\),感性的理解就是在 \(u\) 顶上接长度(边)为 \(d\) 的链后有多少个答案。同样的 \(0\leq d<len_u\)
  • \(g(u,d)\) 的第一种转移来自它的所有儿子的 \(g(v,d+1)\) 之和。
  • 考虑 [一个儿子 \(v\) 的子树] 和 [\(v\) 加入前 \(u\) 维护的子树] 之间产生的答案,则 \(g(u,d):=g(u,d)+f(u,d)f(v,d-1)\)
  • 考虑答案时,也是考虑 [一个儿子 \(v\) 的子树] 和 [\(v\) 加入前 \(u\) 维护的子树] 之间的贡献。当 [\(v\) 加入前 \(u\) 维护的子树] 为空时,答案加上 \(g(v,1)\)。反之,要么是 \(v\) 选两个,\(u\) 选一个,要么是 \(v\) 选一个,\(u\) 选两个,分别是 \(g(v,d+1)f(u,d)\)\(f(v,d-1)g(u,d)\)

到这里应该可以写一个 \(O(n^2)\) DP。考虑优化,因为这里的 \(f,g\) 都和深度有关,所以长链剖分优化,就是先 DP 长儿子,继承长儿子的信息,然后 DP 轻儿子,在合并轻儿子信息时的复杂度为 \(O(len_v)\),然后我们知道每条长链都只在头上被暴力,所有长链之和为 \(O(n)\),所以整个题就 \(O(n)\)

看上去非常简单是吧!其实很难写,这里介绍指针写法,首先可能需要手写一下内存池,写法可供借鉴:

LL _[1<<19],*_pos=_+sizeof(_)/sizeof(LL);
LL* alloc(int sz){return _pos-=sz;}

然后考虑:在我 DP 长儿子的时候,我通过使得 f[son[u]]=f[u]+1,使得长儿子 DP 完后,原来的 f[son[u]][0] 变成 f[u][1](可能有点抽象,建议画出内存条),同理 g[son[u]]=g[u]-1。然后 DP 和刚才上面一样(注意转移顺序,\(ans\to g\to f\))。注意数组越界的问题,我们可以看一下正常写出来怎么样:

for(int i=0;i<hei[v];i++){
	ans+=h[u][i+1]*f[v][i];
	ans+=f[u][i-1]*h[v][i];
}
for(int i=0;i<hei[v];i++){
	h[u][i-1]+=h[v][i];
	h[u][i+1]+=f[u][i+1]*f[v][i];
	f[u][i+1]+=f[v][i];
}

这里注意如果我们枚举 \(0\leq i<len_v\) 的话,访问 \(f(u,i+1)\) 是没有问题的,因为 \(len_u\geq len_v+1\);但是 \(i-1\) 就有问题,我们要判断一下是否有 \(i\geq 1\)

然后看一眼空间问题,\(f\) 只用开一倍空间是没问题的,但是我们要看一下 \(g\),结论是 \(g\) 要开两倍空间,向前 \(len_u\) 向后 \(len_u\),这是因为 g[son[u]]=g[u]-1,所以向前开 \(len_u\);又因为 \(g(u,d)\)\(d\) 的定义域为 \(0\leq d<len_u\),顶上的 \(u\) 向后开 \(len_u\);所以就是这样开数组就够了。(其实这个东西每次传上来的时候会舍弃上一位的 \(0\),然后多算一个 \(len_u\)

那么这题就结束了。

Code

点击查看代码

Rename \(height,len\to hei\)\(g\to h\)


#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
int n,hei[1<<17],son[1<<17];
vector<int> g[1<<17];
LL _[1<<20],*_pos=_+sizeof(_)/sizeof(LL);
LL* alloc(int sz){return _pos-=sz;}
LL*f[1<<17],*h[1<<17],ans=0;
void dfs(int u,int fa){
	for(int v:g[u]) if(v!=fa){
		dfs(v,u);
		if(hei[son[u]]<hei[v]) son[u]=v;
	}
	hei[u]=hei[son[u]]+1;
}
void dp(int u,int fa){
	if(son[u]) f[son[u]]=f[u]+1,h[son[u]]=h[u]-1,dp(son[u],u);
	f[u][0]++;
	ans+=h[u][0],debug("tr0,u=%d,ans+=%lld\n",u,h[u][0]);
	for(int v:g[u]) if(v!=fa&&v!=son[u]){
		f[v]=alloc(hei[v]),h[v]=alloc(hei[v]<<1)+hei[v],dp(v,u);
		for(int i=0;i<hei[v];i++){
			//hei[u]>=hei[v]+1
			ans+=h[u][i+1]*f[v][i],debug("tr1,u=%d,v=%d,i=%d,ans+=%lld*%lld\n",u,v,i,h[u][i+1],f[v][i]);
			if(i>=1) ans+=f[u][i-1]*h[v][i],debug("tr2,u=%d,v=%d,i=%d,ans+=%lld*%lld\n",u,v,i,f[u][i-1],h[v][i]);
		}
		for(int i=0;i<hei[v];i++){
			if(i>=1) h[u][i-1]+=h[v][i];
			h[u][i+1]+=f[u][i+1]*f[v][i];
			f[u][i+1]+=f[v][i];
		}
	}
	debug("son[%d]=%d,hei[%d]=%d\n",u,son[u],u,hei[u]);
	debug("f:"); for(int i=0;i<hei[u];i++) debug("%lld,",f[u][i]); debug("\n");
	debug("h:"); for(int i=0;i<hei[u];i++) debug("%lld,",h[u][i]); debug("\n");
}
int main(){
	scanf("%d",&n);
	for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),g[u].push_back(v),g[v].push_back(u);
	dfs(1,0);
	f[1]=alloc(hei[1]),h[1]=alloc(hei[1]<<1)+hei[1],dp(1,0);
	printf("%lld\n",ans);
	return 0;
}

posted @ 2023-07-25 21:56  caijianhong  阅读(43)  评论(0编辑  收藏  举报