P3241 [HNOI2015]开店
点分树小结
点分树的题其实就是把点分治的询问点换成某几个点,询问换成多组,有时还会带修改,那么我们会发现,每次如果做点分治,找重心的过程都是一样的,所以我们先处理出所有的重心,将其分成不同的层数,将所有重心逐层连起来,避免重复找重心,从而进行修改信息或回答多组询问
两个关键的性质:
-
点分树最大深度是\(log\)级别的
-
点分树上两点的\(lca\)在原树上两点之间的路径上
因此我们可以利用这两个强力的性质,在点分树上用比较暴力的方法解决路径统计问题。
对于这道题
- 首先假设没有\(l,r\)的限制,询问某个点到其他所有点的距离之和,我们用点分树去做该如何做。
便于理解,设三个数组
\(sum[0][x]\)表示点分树上\(x\)子树的所有点到x的距离和
\(sum[1][x]\)表示点分树上\(x\)子树的所有点到x点分树上父亲的距离和
\(siz[x]\)表示点分树上\(x\)子树大小
显然对于询问的点\(x\),我们跳点分树上的父亲即可,初始时\(ans\)设为\(sum[0][x]\),也就是点分树\(x\)子树下方的贡献和
跳点分树父亲时,当前跳到点\(now\)(点\(now\)要有父亲),\(ans += sum[0][fa[now]] - sum[1][now] + (siz[fa[now]] - siz[now]) * Dis(x, fa[now])\)
这是在跳的过程中加上\(x\)上方的贡献。
\(sum[0][fa[now]] - sum[1][now]\)是\(fa\)的除掉\(now\)的子树的点到\(fa\)的距离和,然而这些贡献还少一段,就是\(fa\)到\(x\)之间的一段,所以要加上后面的那部分。
解释一下那两个减法,其实都是除掉已经处理过的子树,\(siz\)可以做减法减去重复计算的,比较简单,注意\(sum\)需要开两个数组,\(sum[0][fa[now]] - sum[1][now]\)不能写成\(sum[0][fa[now]] - sum[0][now]\),其实挺显然的,因为这两个数组的对象不同,第一种写法的两个数组全是到\(fa\)的距离和,而第二种写法是到\(fa\)和到\(now\)的距离和,所以不能用第二种方法去重
- 现在有了\(l,r\)的限制
其实就是再加一维(存成\(vector\)),存上子树中各点的信息就好了,按照年龄排序,距离做个前缀和,然后询问\(l,r\)也就相当于\([0,r] - [0,l-1]\),二分出\(vector\)中\(l-1,r\)的位置即可。因为点分树每层节点的\(vector\)元素总和是\(n\),共\(log\)层,所以空间\(nlogn\),可以接受
然后还有个小优化,\(siz\)其实不用单独存个\(vector\),用二分后\(sum\)的\(vector\)的下标之差就能算出\(siz\),另外我点分树一直写树剖\(lca\),感觉比\(ST\)表预处理要快一点
- 总结:其实最主要的就是那个式子,初始\(ans = sum[0][x]\), 点分树跳父亲时, \(ans += sum[0][fa[now]] - sum[1][now] + (siz[fa[now]] - siz[now]) * Dis(x, fa[now])\) ,其他的都是套路了
代码
#include <cmath>
#include <ctime>
#include <vector>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define rint register int
using namespace std;
typedef long long ll;
const int maxn = 1.5e5 + 10;
const int inf = 0x3f3f3f3f;
int n, q, A, maxsiz, tsiz, rt, cnt, head[maxn], v[maxn], Fa[maxn];
int dep[maxn], dis[maxn], siz[maxn], son[maxn], top[maxn], fa[maxn];
bool vis[maxn];
ll lastans;
struct Edge {
int to, nxt, val;
}e[maxn << 1];
struct Node {
int v; ll sum;
bool operator < (const Node &B) const {
return v < B.v;
}
};
vector < Node > vec[2][maxn];
template <typename T> T read(register T x = 0, register bool f = 0, register char ch = getchar()) {
for(;!isdigit(ch);ch = getchar()) f = ch == '-';
for(; isdigit(ch);ch = getchar()) x = (x << 3) + (x << 1) + (ch & 15);
return f ? -x : x ;
}
void add(rint x, rint y, rint z) {
e[++cnt] = (Edge){y, head[x], z}, head[x] = cnt;
}
void dfs1(rint x, rint prt) {
fa[x] = prt, dep[x] = dep[prt] + 1, siz[x] = 1;
for(rint i = head[x], y;i;i = e[i].nxt) {
if((y = e[i].to) == prt) continue;
dis[y] = dis[x] + e[i].val;
dfs1(y, x);
siz[x] += siz[y];
if(!son[x] || siz[y] > siz[son[x]]) son[x] = y;
}
}
void dfs2(rint x, rint tp) {
top[x] = tp;
if(son[x]) dfs2(son[x], tp);
for(rint i = head[x], y;i;i = e[i].nxt) {
if((y = e[i].to) != fa[x] && y != son[x]) dfs2(y, y);
}
}
int Dis(rint a, rint b) {
rint x = a, y = b;
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
x = dep[x] < dep[y] ? x : y;
return dis[a] + dis[b] - dis[x] * 2;
}
void getrt(rint x, rint prt) {
siz[x] = 1;
rint maxs = 0;
for(rint i = head[x], y;i;i = e[i].nxt) {
if(vis[y = e[i].to] || y == prt) continue;
getrt(y, x);
siz[x] += siz[y];
maxs = max(maxs, siz[y]);
}
maxs = max(maxs, tsiz - siz[x]);
if(maxs < maxsiz) maxsiz = maxs, rt = x;
}
void dfs(rint x, rint prt, rint sum) {
siz[x] = 1;
vec[0][rt].push_back((Node){v[x], sum});
if(Fa[rt]) vec[1][rt].push_back((Node){v[x], Dis(x, Fa[rt])});
for(rint i = head[x], y;i;i = e[i].nxt) {
if(vis[y = e[i].to] || y == prt) continue;
dfs(y, x, sum + e[i].val);
siz[x] += siz[y];
}
}
void solve(rint x) {
vis[x] = 1, dfs(x, 0, 0);
for(rint i = head[x], y;i;i = e[i].nxt) {
if(vis[y = e[i].to]) continue;
tsiz = siz[y], maxsiz = inf, getrt(y, x);
Fa[rt] = x;
solve(rt);
}
}
ll query(rint opt, rint x, rint l, rint r, ll &ss) {
rint lef = lower_bound(vec[opt][x].begin(), vec[opt][x].end(), (Node){l, 0}) - vec[opt][x].begin() - 1;
rint rig = upper_bound(vec[opt][x].begin(), vec[opt][x].end(), (Node){r, 0}) - vec[opt][x].begin() - 1;
ss = rig - lef;
ll ans = 0;
if(rig >= 0 && rig < (int) vec[opt][x].size()) ans += vec[opt][x][rig].sum;
if(lef >= 0 && lef < (int) vec[opt][x].size()) ans -= vec[opt][x][lef].sum;
return ans;
}
int main() {
n = read<int>(), q = read<int>(), A = read<int>();
for(rint i = 1;i <= n; ++i) v[i] = read<int>();
for(rint i = 1, x, y, z;i < n; ++i) {
x = read<int>(), y = read<int>(), z = read<int>();
add(x, y, z), add(y, x, z);
}
dfs1(1, 0), dfs2(1, 1);
tsiz = n, maxsiz = inf, getrt(1, 0), solve(rt);
for(rint i = 1;i <= n; ++i) {
sort(vec[0][i].begin(), vec[0][i].end());
sort(vec[1][i].begin(), vec[1][i].end());
for(rint j = 1;j < (int)vec[0][i].size(); ++j) vec[0][i][j].sum += vec[0][i][j - 1].sum;
for(rint j = 1;j < (int)vec[1][i].size(); ++j) vec[1][i][j].sum += vec[1][i][j - 1].sum;
}
for(rint i = 1, x, l, r;i <= q; ++i) {
x = read<int>(), l = (read<ll>() + lastans) % A, r = (read<ll>() + lastans) % A;
l > r ? swap(l, r) : void();
ll s1, s2;
lastans = query(0, x, l, r, s1);
for(rint now = x;Fa[now];now = Fa[now]) {
lastans += query(0, Fa[now], l, r, s2) - query(1, now, l, r, s1);
lastans += (s2 - s1) * Dis(x, Fa[now]);
}
printf("%lld\n", lastans);
}
return 0;
}