luogu4074 [WC2013]糖果公园(树上带修莫队)
题目大意:给一个树,树上每个点都有一种颜色,每个颜色都有一个收益
每次修改一个点上的颜色
或询问一条链上所有颜色第i次遇到颜色j可以获得w[i]*v[j]的价值,求链上价值和
题解:树上带修莫队
按照带修莫队那套理论,我们要把树分成若干个块满足每个块内联通
把询问按照左端点所在块为第一关键字,右端点所在块为第二关键字,时间为第三关键字排序
假设块的大小为 \(B\)
则换块次数为 \(O((\frac{N}{B})^2)\) ,每次换块复杂度 \(O(n)\)
不换块每块内时间维度修改复杂度为 \(O(q)\),总共为 $ O((\frac{N}{B})^2q)$
每次询问块内改节点是 \(O(B)\),所以总复杂度(如果默认n,q同阶) 为 \(O(NB+\frac{N^3}{B^2})\),发现取 \(B=N^{2/3}\) 时值最小,为 \(O(N^{5/3})\)
我们每次维护当前 x 到 rt 路径上点 和 y 到 rt 路径上点 的集合对称差(差不多就是异或)(开一个bool数组记录某个点是否被维护)(说白了就是维护x到y路径除了lca的所有点)
从 (x,y) 到 (nx,ny) 转移时,我们只需要改变x,y,nx,ny到rt的值,其实就是x到nx路径上除了lca的所有点以及y到ny路径上除了lca以外的所有点(比较方便维护)
求答案时候,我们临时把lca的值加进去,最后再减回来就行
注意w前缀和要开long long,否则会炸
#include <cmath>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
struct fuck { int x, y, id; } a[200010];
int n, m, q, sz, top, tot, bcnt;
int v[200010], bid[200010], bucket[200010];
int fa[200010][19], s[200010], c[200010], tmp[200010];
vector<int> out[200010];
int pos[200010], from[200010], to[200010], depth[200010];
bool in[200010];
long long sum, ans[200010], w[200010];
bool fucked(const fuck &a, const fuck &b)
{
if (bid[a.x] != bid[b.x]) return bid[a.x] < bid[b.x];
if (bid[a.y] != bid[b.y]) return bid[a.y] < bid[b.y];
return a.id < b.id;
}
void dfs(int x)
{
int tmp = top;
for (int i : out[x]) if (fa[x][0] != i)
{
fa[i][0] = x, depth[i] = depth[x] + 1, dfs(i);
if (top - tmp >= sz) { bcnt++; while (top != tmp) bid[s[top--]] = bcnt; }
}
s[++top] = x;
}
int lca(int x, int y)
{
if (depth[x] < depth[y]) swap(x, y);
for (int i = 18; i >= 0; i--) if (depth[fa[x][i]] >= depth[y]) x = fa[x][i];
if (x == y) return x;
for (int i = 18; i >= 0; i--) if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
void del(int x) { sum -= v[x] * w[bucket[x]]; bucket[x]--; sum += v[x] * w[bucket[x]]; }
void add(int x) { sum -= v[x] * w[bucket[x]]; bucket[x]++; sum += v[x] * w[bucket[x]]; }
void chenge(int x) { if (in[x]) del(c[x]); else add(c[x]); in[x] ^= 1; }
void chenge(int x, int y)
{
int lc = lca(x, y);
while (x != lc) chenge(x), x = fa[x][0];
while (y != lc) chenge(y), y = fa[y][0];
}
int main()
{
scanf("%d%d%d", &n, &m, &q), sz = cbrt(n), sz = sz * sz;
for (int i = 1; i <= m; i++) scanf("%d", &v[i]);
for (int i = 1; i <= n; i++) scanf("%lld", &w[i]), w[i] += w[i - 1];
for (int x, y, i = 1; i < n; i++) scanf("%d%d", &x, &y), out[x].push_back(y), out[y].push_back(x);
for (int i = 1; i <= n; i++) scanf("%d", &c[i]), tmp[i] = c[i];
depth[1] = 1, dfs(1); if (top == 0) { top++; } while (top) bid[s[top--]] = bcnt;
for (int j = 1; j <= 18; j++) for (int i = 1; i <= n; i++) fa[i][j] = fa[fa[i][j - 1]][j - 1];
for (int opd, x, y, i = 1; i <= q; i++)
{
scanf("%d%d%d", &opd, &x, &y);
if (opd) { if (bid[x] > bid[y]) { swap(x, y); } a[++tot] = (fuck){x, y, i}; }
else pos[i] = x, from[i] = tmp[x], tmp[x] = y, to[i] = y;
}
sort(a + 1, a + 1 + tot, fucked);
for (int x = 1, y = 1, t = 0, i = 1; i <= tot; i++)
{
chenge(x, a[i].x), chenge(y, a[i].y), x = a[i].x, y = a[i].y;
for (; t < a[i].id; t++) if (pos[t]) { c[pos[t]] = to[t]; if (in[pos[t]]) del(from[t]), add(to[t]); }
for (; t > a[i].id; t--) if (pos[t]) { c[pos[t]] = from[t]; if (in[pos[t]]) del(to[t]), add(from[t]); }
add(c[lca(x, y)]), ans[a[i].id] = sum, del(c[lca(x, y)]);
}
for (int i = 1; i <= q; i++) if (pos[i] == 0) printf("%lld\n", ans[i]);
return 0;
}