【题解】CF1178G The Awesomest Vertex
题意
给定一棵大小为 \(n\) 的树以及 \(m\) 个操作,
定义一个结点 \(u\) 的权值为:
\(|\sum\limits_{w \in R(v)} a_w| \cdot |\sum\limits_{w \in R(v)} b_w|\)
其中 \(R(v)\) 表示结点 \(v\) 的祖先结点(含自身)
每次操作可以:
-
将结点 \(u\) 的权值加 \(x\)
-
询问结点 \(u\) 的子树内最大权值
\(n, m \leq 2 \times 10^5\)
思路
分块 + 凸包。
首先子树问题考虑 dfs 序拍平成序列问题,发现 \(|\sum\limits_{w \in R(v)} b_w|\) 可以直接预处理出来,是一个常数。
那么每次操作的影响形如 \(|\sum\limits_{w \in R(v)} a_w + x| \cdot b(v)\),将 \(b(v)\) 看成斜率,那么这里就是一条直线的形式。
绝对值很烦,考虑变成 \(\max(a_w + x, -a_w - x)\) 分别维护,最后取较大值即可。
于是问题变成在序列上维护若干条直线,每次问在特定位置的最高点坐标,是凸包套路。
但是,在序列上维护区间凸包?显然线段树不可做,考虑万用分块。
每次修改,对于整块中的直线整体上移,所以只需要考虑给它打求和标记。对于散块可以暴力重构。
每次询问,对于整块可以直接在凸包上二分求最高点坐标,对于散块可以暴力查询。
发现这样做的复杂度带一只 \(\log\),考虑优化。
首先 \(x\) 坐标单调,可以直接指针代替二分。
每次重构会将直线按斜率升序加入单调栈,这里可以初始时按斜率排好序。
时间复杂度 \(O(n \sqrt{n})\),magic!
代码
#include <cstdio>
#include <cmath>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
const int maxn = 4e5 + 5;
int n, m;
int cnt, bs = 10;
int a[maxn], b[maxn], inp[maxn], outp[maxn];
vector<int> g[maxn];
struct line
{
int k;
ll b;
inline ll gety(int x) { return 1ll * k * x + b; }
inline void update(int x) { b += 1ll * k * x; }
} ;
double interx(const line& a, const line& b)
{
if (a.k == b.k) return (a.b > b.b ? 1e10 : -1e10);
return (double)(b.b - a.b) / (a.k - b.k);
}
struct block
{
line q[305], *a;
int x, pos, top;
double k[305];
ll query()
{
while ((pos < top) && (k[pos + 1] < x)) pos++;
return q[pos].gety(x);
}
void build()
{
if (x) for (int i = 0; i < bs; i++) a[i].update(x);
x = top = 0, pos = 1;
for (int i = 0; i < bs; i++)
{
while ((top > 1) && (k[top] > interx(q[top], a[i]))) top--;
q[++top] = a[i];
k[top] = interx(q[top - 1], q[top]);
}
}
} ;
struct data
{
line x;
int p;
bool operator < (const data& rhs) const { return (x.k == rhs.x.k ? x.b < rhs.x.b : x.k < rhs.x.k); }
} seq[305];
struct item
{
block b[705];
line a[maxn];
int st[maxn];
void build()
{
for (int i = 0; i <= n; i += bs)
{
b[i / bs].a = a + i;
for (int j = 0; j < bs; j++) seq[j] = (data){a[j + i], j + i};
sort(seq, seq + bs);
for (int j = 0; j < bs; j++)
{
a[j + i] = seq[j].x;
st[seq[j].p] = j + i;
}
b[i / bs].build();
}
}
ll query(int l, int r)
{
int bl = l / bs, br = r / bs;
ll ans = -(1ll << 60);
if (bl == br) for (int i = l; i <= r; i++) ans = max(ans, a[st[i]].gety(b[bl].x));
else
{
bl++;
for (int i = l; i < bl * bs; i++) ans = max(ans, a[st[i]].gety(b[bl - 1].x));
for (int i = br * bs; i <= r; i++) ans = max(ans, a[st[i]].gety(b[br].x));
for (int i = bl; i < br; i++) ans = max(ans, b[i].query());
}
return ans;
}
void update(int l, int r, int x)
{
int bl = l / bs, br = r / bs;
if (bl == br)
{
for (int i = l; i <= r; i++) a[st[i]].update(x);
b[bl].build();
}
else
{
bl++;
for (int i = l; i < bl * bs; i++) a[st[i]].update(x);
b[bl - 1].build();
for (int i = br * bs; i <= r; i++) a[st[i]].update(x);
b[br].build();
for (int i = bl; i < br; i++) b[i].x += x;
}
}
} s1, s2;
inline int read()
{
int res = 0, flag = 1;
char ch = getchar();
while ((ch < '0') || (ch > '9'))
{
if (ch == '-') flag = -1;
ch = getchar();
}
while ((ch >= '0') && (ch <= '9'))
{
res = res * 10 + ch - '0';
ch = getchar();
}
return res * flag;
}
inline void write(ll x)
{
if (x < 0) putchar('-'), x = -x;
if (x > 9) write(x / 10);
putchar(x % 10 + '0');
}
void dfs(int u)
{
inp[u] = ++cnt;
for (int v : g[u])
{
a[v] += a[u], b[v] += b[u];
dfs(v);
}
outp[u] = cnt;
}
int main()
{
scanf("%d%d", &n, &m);
while (bs * bs * 12 <= n * 5) bs++;
for (int i = 2, f; i <= n; i++)
{
scanf("%d", &f);
g[f].push_back(i);
}
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
dfs(1);
s1.a[0] = s2.a[0] = (line){0, -(1ll << 60)};
for (int i = 1; i <= n; i++)
{
if (b[i] < 0) b[i] = -b[i];
s1.a[inp[i]] = (line){b[i], 1ll * a[i] * b[i]};
s2.a[inp[i]] = (line){-b[i], -1ll * a[i] * b[i]};
}
for (int i = n + 1; i < n / bs * bs + bs; i++) s1.a[i] = s2.a[i] = (line){0, -(1ll << 60)};
s1.build(), s2.build();
for (int i = 1, opt, u, x; i <= m; i++)
{
scanf("%d", &opt);
if (opt == 1)
{
scanf("%d%d", &u, &x);
s1.update(inp[u], outp[u], x);
s2.update(inp[u], outp[u], x);
}
else
{
scanf("%d", &u);
printf("%lld\n", max(s1.query(inp[u], outp[u]), s2.query(inp[u], outp[u])));
}
}
return 0;
}