P3241 [HNOI2015]开店

点分树小结

点分树的题其实就是把点分治的询问点换成某几个点,询问换成多组,有时还会带修改,那么我们会发现,每次如果做点分治,找重心的过程都是一样的,所以我们先处理出所有的重心,将其分成不同的层数,将所有重心逐层连起来,避免重复找重心,从而进行修改信息或回答多组询问

两个关键的性质:

  • 点分树最大深度是\(log\)级别的

  • 点分树上两点的\(lca\)在原树上两点之间的路径上

因此我们可以利用这两个强力的性质,在点分树上用比较暴力的方法解决路径统计问题。

对于这道题

  • 首先假设没有\(l,r\)的限制,询问某个点到其他所有点的距离之和,我们用点分树去做该如何做。

便于理解,设三个数组

\(sum[0][x]\)表示点分树上\(x\)子树的所有点到x的距离和

\(sum[1][x]\)表示点分树上\(x\)子树的所有点到x点分树上父亲的距离和

\(siz[x]\)表示点分树上\(x\)子树大小

显然对于询问的点\(x\),我们跳点分树上的父亲即可,初始时\(ans\)设为\(sum[0][x]\),也就是点分树\(x\)子树下方的贡献和

跳点分树父亲时,当前跳到点\(now\)(点\(now\)要有父亲),\(ans += sum[0][fa[now]] - sum[1][now] + (siz[fa[now]] - siz[now]) * Dis(x, fa[now])\)

这是在跳的过程中加上\(x\)上方的贡献。

\(sum[0][fa[now]] - sum[1][now]\)\(fa\)的除掉\(now\)的子树的点到\(fa\)的距离和,然而这些贡献还少一段,就是\(fa\)\(x\)之间的一段,所以要加上后面的那部分。

解释一下那两个减法,其实都是除掉已经处理过的子树,\(siz\)可以做减法减去重复计算的,比较简单,注意\(sum\)需要开两个数组,\(sum[0][fa[now]] - sum[1][now]\)不能写成\(sum[0][fa[now]] - sum[0][now]\),其实挺显然的,因为这两个数组的对象不同,第一种写法的两个数组全是到\(fa\)的距离和,而第二种写法是到\(fa\)和到\(now\)的距离和,所以不能用第二种方法去重

  • 现在有了\(l,r\)的限制

其实就是再加一维(存成\(vector\)),存上子树中各点的信息就好了,按照年龄排序,距离做个前缀和,然后询问\(l,r\)也就相当于\([0,r] - [0,l-1]\),二分出\(vector\)\(l-1,r\)的位置即可。因为点分树每层节点的\(vector\)元素总和是\(n\),共\(log\)层,所以空间\(nlogn\),可以接受

然后还有个小优化,\(siz\)其实不用单独存个\(vector\),用二分后\(sum\)\(vector\)的下标之差就能算出\(siz\),另外我点分树一直写树剖\(lca\),感觉比\(ST\)表预处理要快一点

  • 总结:其实最主要的就是那个式子,初始\(ans = sum[0][x]\), 点分树跳父亲时, \(ans += sum[0][fa[now]] - sum[1][now] + (siz[fa[now]] - siz[now]) * Dis(x, fa[now])\) ,其他的都是套路了

代码

#include <cmath>
#include <ctime>
#include <vector>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define rint register int
using namespace std;
typedef long long ll;

const int maxn = 1.5e5 + 10;
const int inf = 0x3f3f3f3f;
int n, q, A, maxsiz, tsiz, rt, cnt, head[maxn], v[maxn], Fa[maxn];
int dep[maxn], dis[maxn], siz[maxn], son[maxn], top[maxn], fa[maxn];
bool vis[maxn];
ll lastans;

struct Edge {
	int to, nxt, val;
}e[maxn << 1];

struct Node {
	int v; ll sum;
	bool operator < (const Node &B) const {
		return v < B.v;
	}
};
vector < Node > vec[2][maxn];

template <typename T> T read(register T x = 0, register bool f = 0, register char ch = getchar()) {
	for(;!isdigit(ch);ch = getchar()) f = ch == '-';
	for(; isdigit(ch);ch = getchar()) x = (x << 3) + (x << 1) + (ch & 15);
	return f ? -x : x ;
}

void add(rint x, rint y, rint z) {
	e[++cnt] = (Edge){y, head[x], z}, head[x] = cnt;
}

void dfs1(rint x, rint prt) {
	fa[x] = prt, dep[x] = dep[prt] + 1, siz[x] = 1;
	for(rint i = head[x], y;i;i = e[i].nxt) {
		if((y = e[i].to) == prt) continue;
		dis[y] = dis[x] + e[i].val;
		dfs1(y, x);
		siz[x] += siz[y];
		if(!son[x] || siz[y] > siz[son[x]]) son[x] = y;
	}
}

void dfs2(rint x, rint tp) {
	top[x] = tp;
	if(son[x]) dfs2(son[x], tp);
	for(rint i = head[x], y;i;i = e[i].nxt) {
		if((y = e[i].to) != fa[x] && y != son[x]) dfs2(y, y);
	}
}

int Dis(rint a, rint b) {
	rint x = a, y = b;
	while(top[x] != top[y]) {
		if(dep[top[x]] < dep[top[y]]) swap(x, y);
		x = fa[top[x]];
	}
	x = dep[x] < dep[y] ? x : y;
	return dis[a] + dis[b] - dis[x] * 2;
}

void getrt(rint x, rint prt) {
	siz[x] = 1;
	rint maxs = 0;
	for(rint i = head[x], y;i;i = e[i].nxt) {
		if(vis[y = e[i].to] || y == prt) continue;
		getrt(y, x);
		siz[x] += siz[y];
		maxs = max(maxs, siz[y]);
	}
	maxs = max(maxs, tsiz - siz[x]);
	if(maxs < maxsiz) maxsiz = maxs, rt = x;
}

void dfs(rint x, rint prt, rint sum) {
	siz[x] = 1;
	vec[0][rt].push_back((Node){v[x], sum});
	if(Fa[rt]) vec[1][rt].push_back((Node){v[x], Dis(x, Fa[rt])});
	for(rint i = head[x], y;i;i = e[i].nxt) {
		if(vis[y = e[i].to] || y == prt) continue;
		dfs(y, x, sum + e[i].val);
		siz[x] += siz[y];
	}
}

void solve(rint x) {
	vis[x] = 1, dfs(x, 0, 0);
	for(rint i = head[x], y;i;i = e[i].nxt) {
		if(vis[y = e[i].to]) continue;
		tsiz = siz[y], maxsiz = inf, getrt(y, x);
		Fa[rt] = x;
		solve(rt);
	}
}

ll query(rint opt, rint x, rint l, rint r, ll &ss) {
	rint lef = lower_bound(vec[opt][x].begin(), vec[opt][x].end(), (Node){l, 0}) - vec[opt][x].begin() - 1;
	rint rig = upper_bound(vec[opt][x].begin(), vec[opt][x].end(), (Node){r, 0}) - vec[opt][x].begin() - 1;
	ss = rig - lef;
	ll ans = 0;
	if(rig >= 0 && rig < (int) vec[opt][x].size()) ans += vec[opt][x][rig].sum;
	if(lef >= 0 && lef < (int) vec[opt][x].size()) ans -= vec[opt][x][lef].sum;
	return ans;
}

int main() {
	n = read<int>(), q = read<int>(), A = read<int>();
	for(rint i = 1;i <= n; ++i) v[i] = read<int>();
	for(rint i = 1, x, y, z;i < n; ++i) {
		x = read<int>(), y = read<int>(), z = read<int>();
		add(x, y, z), add(y, x, z);
	}
	dfs1(1, 0), dfs2(1, 1);
	tsiz = n, maxsiz = inf, getrt(1, 0), solve(rt);
	for(rint i = 1;i <= n; ++i) {
		sort(vec[0][i].begin(), vec[0][i].end());
		sort(vec[1][i].begin(), vec[1][i].end());
		for(rint j = 1;j < (int)vec[0][i].size(); ++j) vec[0][i][j].sum += vec[0][i][j - 1].sum;
		for(rint j = 1;j < (int)vec[1][i].size(); ++j) vec[1][i][j].sum += vec[1][i][j - 1].sum;
	}
	for(rint i = 1, x, l, r;i <= q; ++i) {
		x = read<int>(), l = (read<ll>() + lastans) % A, r = (read<ll>() + lastans) % A;
		l > r ? swap(l, r) : void();
		ll s1, s2;
		lastans = query(0, x, l, r, s1);
		for(rint now = x;Fa[now];now = Fa[now]) {
			lastans += query(0, Fa[now], l, r, s2) - query(1, now, l, r, s1);
			lastans += (s2 - s1) * Dis(x, Fa[now]);
		}
		printf("%lld\n", lastans);
	}
	return 0;
}
posted @ 2021-01-14 21:45  liuzhaoxu  阅读(63)  评论(0编辑  收藏  举报