【UOJ 284】快乐游戏鸡(贪心)(长链剖分)(线段树)
快乐游戏鸡
题目链接:UOJ 284
题目大意
给你一棵有根数,然后多组询问,每次告诉你起点终点。
你可以从一个点走到它的儿子,然后如果你当前死亡次数小于当前点权,你就会死并回到起点。
然后边权长度都为 1,然后要你用最小的路径和走到终点,其中起点终点的点权不算,即不会在起点终点死。
思路
首先单次询问,考虑一个贪心,就是深度从小往大看,每次就选死的次数最多的。
然后由于你维护的序列是有关深度的,考虑用长链剖分。
那就是一条链维护一个线段树第 \(i\) 个位置表示距离 \(\rm top\) 深度为 \(i\) 的最大权值。
然后轻边就暴力枚举轻边连接的链的线段树。
然后发现直接搞式子很乱,考虑化简一下:
\(\sum\limits_{i=1}^d(f_{now,i}-f_{now,i-1})i\)
\(=\sum\limits_{i=1}^d(f_{now,i}i-f_{now,i-1}i)\)
\(=-f_{now,0}(0)+\sum\limits_{i=1}^{d-1}((i+1)f_{now,i}-if_{now,i})+df_{now,d}\)
\(=df_{now,d}-\sum\limits_{i=1}^{d-1}f_{now,i}\)
(补一句 \(f_{i,j}\) 是 \(i\) 的子树深度不超过 \(j\) 的最大权值)
然后我们就只需要维护 \(f_{now,i}\)。
然后你只需要线段树里面区间覆盖,区间和,求第一个大于等于某个数的位置,就还要区间最大值(用来求前面那个)。
(注意到因为是单调的我们才可以这么求)
代码
#include<cstdio>
#include<vector>
#define ll long long
using namespace std;
const int N = 3e5 + 100;
struct node {
int x, num, id;
};
int n, w[N], fa[N][21], q, deg[N], maxdeg[N], son[N], maxn[N][21], top[N], id[N], tot;
vector <int> G[N]; vector <node> qu[N];
ll ans[N];
struct XD_tree {
int n;
vector <int> lzy, maxn;
vector <ll> sum;
void Init(int m) {
lzy.resize(m << 2); maxn.resize(m << 2); sum.resize(m << 2); n = m;
}
void up(int now) {
sum[now] = sum[now << 1] + sum[now << 1 | 1];
maxn[now] = max(maxn[now << 1], maxn[now << 1 | 1]);
}
void downc(int now, int sz, int val) {
maxn[now] = val; lzy[now] = val;
sum[now] = 1ll * sz * val;
}
void down(int now, int l, int r) {
if (lzy[now]) {
int mid = (l + r) >> 1;
downc(now << 1, mid - l + 1, lzy[now]); downc(now << 1 | 1, r - mid, lzy[now]);
lzy[now] = 0;
}
}
void change(int now, int l, int r, int L, int R, int val) {
if (L > R) return ;
if (L <= l && r <= R) {
downc(now, r - l + 1, val);
return ;
}
int mid = (l + r) >> 1; down(now, l, r);
if (L <= mid) change(now << 1, l, mid, L, R, val); if (mid < R) change(now << 1 | 1, mid + 1, r, L, R, val);
up(now);
}
int find(int now, int l, int r, int val) {
if (l == r) return (maxn[now] >= val) ? l : n;
int mid = (l + r) >> 1; down(now, l, r);
if (maxn[now << 1] >= val) return find(now << 1, l, mid, val);
else return find(now << 1 | 1, mid + 1, r, val);
}
void insert(int now, int l, int r, int pl, int val) {
int R = find(now, l, r, val);
change(now, l, r, pl, R - 1, val);
}
ll query(int now, int l, int r, int L, int R) {
if (L > R) return 0;
if (L <= l && r <= R) return sum[now];
int mid = (l + r) >> 1; down(now, l, r); ll re = 0;
if (L <= mid) re += query(now << 1, l, mid, L, R); if (mid < R) re += query(now << 1 | 1, mid + 1, r, L, R);
return re;
}
}T[N];
void dfs1(int now) {
maxn[now][0] = w[fa[now][0]];
for (int i = 1; i <= 20; i++) fa[now][i] = fa[fa[now][i - 1]][i - 1], maxn[now][i] = max(maxn[now][i - 1], maxn[fa[now][i - 1]][i - 1]);
for (int i = 0; i < G[now].size(); i++) { int x = G[now][i];
deg[x] = deg[now] + 1; dfs1(x);
if (maxdeg[x] > maxdeg[son[now]]) son[now] = x;
}
maxdeg[now] = maxdeg[son[now]] + 1;
}
void dfs2(int now) {
if (top[now] == now) {
id[now] = ++tot; T[tot].Init(maxdeg[now]);
}
if (son[now]) top[son[now]] = top[now], dfs2(son[now]);
for (int i = 0; i < G[now].size(); i++) { int x = G[now][i];
if (x == son[now]) continue;
top[x] = x; dfs2(x);
}
}
int get_max(int x, int y) {
swap(x, y); int re = 0;
for (int i = 20; i >= 0; i--)
if (deg[fa[fa[x][i]][0]] >= deg[y]) re = max(re, maxn[x][i]), x = fa[x][i];
return re;
}
void dfs(int now) {
if (son[now]) dfs(son[now]);
for (int i = 0; i < G[now].size(); i++) { int x = G[now][i];
if (x == son[now]) continue; dfs(x);
for (int j = 0; j < maxdeg[x]; j++) T[id[top[now]]].insert(1, 0, maxdeg[top[now]] - 1, deg[x] - deg[top[now]] + j, T[id[x]].query(1, 0, maxdeg[x] - 1, j, j));
}
for (int i = 0; i < qu[now].size(); i++) {
int ID = qu[now][i].id, x = qu[now][i].x, num = qu[now][i].num;
int pla = T[id[top[now]]].find(1, 0, maxdeg[top[now]] - 1, num);
ans[ID] = 1ll * (pla - (deg[now] - deg[top[now]])) * num;
ans[ID] -= T[id[top[now]]].query(1, 0, maxdeg[top[now]] - 1, deg[now] - deg[top[now]] + 1, pla - 1);
ans[ID] += deg[x] - deg[now];
}
T[id[top[now]]].insert(1, 0, maxdeg[top[now]] - 1, deg[now] - deg[top[now]], w[now]);
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
for (int i = 2; i <= n; i++) {
scanf("%d", &fa[i][0]); G[fa[i][0]].push_back(i);
}
deg[0] = -1; dfs1(1); top[1] = 1; dfs2(1);
scanf("%d", &q);
for (int i = 1; i <= q; i++) {
int x, y; scanf("%d %d", &x, &y);
if (x == y || x == fa[y][0]) {ans[i] = deg[y] - deg[x]; continue;}
qu[x].push_back((node){y, get_max(x, y), i});
}
dfs(1);
for (int i = 1; i <= q; i++) printf("%lld\n", ans[i]);
return 0;
}