LG4074【WC2013】糖果公园 【树上莫队,带修莫队】
题目描述:给出一棵 \(n\) 个点的树,点有颜色 \(C_i\),长度为 \(m\) 的数组 \(V\) 和长度为 \(n\) 的数组 \(W\)。有两种操作:
-
将 \(C_x\) 修改为 \(y\)。
-
求 \(u\) 到 \(v\) 的链的 \(\sum\limits_{i=1}^m\sum\limits_{j=1}^{cnt_i}W_j\),其中 \(cnt_i\) 表示颜色 \(i\) 的出现次数。
数据范围:\(n,m\le 10^5,1\le C_i\le m\),时限6s(洛谷)或8s(UOJ)。
这是树上带修莫队的模板题。
首先我们看树分块怎么做,所以要来先做这道题。
直接讲做法了:我们是尽可能在底层把大小 \(\ge B\) 的联通块作为一块,剩下的扔给父亲合并。就是要开一个stack,维护当前还没有被分块的点。不停递归儿子,一旦已经有一个 \(\ge B\) 的连通块了,就把它们作为一块,设首都为 \(x\)(当前dfs的点)。最后把 \(x\) 放进栈中。最后递归完还要把栈中剩下的点放入最后一个块,并把首都设为 \(1\)。
inline void dfs(int x, int f){
int t = top;
for(Rint i = head[x];i;i = nxt[i])
if(to[i] != f){
dfs(to[i], x);
if(top >= t + B){
rt[++ num] = x;
while(top > t) in[stk[top --]] = num;
}
}
stk[++ top] = x;
}
int main(){
// ...
dfs(1, 0);
if(!num) num = 1;
rt[num] = 1;
while(top) in[stk[top --]] = num;
// ...
}
我们知道没有合并,即大小 \(<B\),合并之后 \(<2B\)。就算最后一个连通块也至多有 \(<3B\),所以是比较均匀的。
然后我们看看如何从链 \((u,v)\) 变为 \((u',v')\) 并且时间复杂度为 \(O(\text{len}(u,u')+\text{len}(v,v'))\)。
首先我们改为维护 \((u,v)\) 中抠掉 \(lca\) 的答案,\(cnt\) 和每个点是否在里面的 \(vis\)。设 \(L(u,v)\) 为 \(u\) 到 \(v\) 这条链上抠掉 \(lca\) 的点集,\(\oplus\) 为集合对称差(\(A\oplus B=(A\cup B)-(A\cap B)\)),\(S(u)\) 为 \(1\) 到 \(u\) 的这条链的点集。则 \(L(u,v)=S(u)\oplus S(v)\),且集合对称差肯定是有交换、结合律的。
于是就直接是将 \(L(u,u')\) 和 \(L(v,v')\) 全部 \(vis\) 取反就是 \(L(u',v')\),然后把 \(lca\) 取反就是 \(u'\) 到 \(v'\) 这条链。
至于带修莫队怎么做,就去看这道题。取 \(B\) 比 \(n^\frac{2}{3}\) 少一点点就可以,时间复杂度 \(O(n^\frac{5}{3}+n\log n)\)。
#include<bits/stdc++.h>
#define Rint register int
using namespace std;
typedef long long LL;
const int N = 100003;
int n, m, B, q, ql, qr, qnow, qnum, cnum, V[N], W[N], C[N], head[N], to[N << 1], nxt[N << 1], dfn[N], cnt[N];
bool vis[N];
LL ans[N], qans;
inline void add(int a, int b){
static int cnt = 0;
to[++ cnt] = b; nxt[cnt] = head[a]; head[a] = cnt;
}
int dep[N], top[N], fa[N], siz[N], wson[N], stk[N], tp, bnum, bel[N];
inline void dfs1(int x){
int tmp = tp;
siz[x] = 1;
for(Rint i = head[x];i;i = nxt[i])
if(to[i] != fa[x]){
fa[to[i]] = x; dep[to[i]] = dep[x] + 1;
dfs1(to[i]);
siz[x] += siz[to[i]];
if(siz[to[i]] > siz[wson[x]]) wson[x] = to[i];
if(tp >= tmp + B){
++ bnum;
while(tp > tmp) bel[stk[tp --]] = bnum;
}
}
stk[++ tp] = x;
}
inline void dfs2(int x, int topf){
top[x] = topf;
if(wson[x]) dfs2(wson[x], topf);
for(Rint i = head[x];i;i = nxt[i])
if(to[i] != wson[x] && to[i] != fa[x])
dfs2(to[i], to[i]);
}
inline int lca(int u, int v){
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u, v);
u = fa[top[u]];
}
return dep[u] < dep[v] ? u : v;
}
struct Query {
int u, v, id, tim;
inline bool operator < (const Query &o) const {
if(bel[u] != bel[o.u]) return bel[u] < bel[o.u];
if(bel[v] != bel[o.v]) return bel[v] < bel[o.v];
return tim < o.tim;
}
} que[N];
struct Change {
int u, val;
} cha[N];
inline void work(int x){
if(vis[x]) qans -= (LL) V[C[x]] * W[cnt[C[x]]], -- cnt[C[x]];
else ++ cnt[C[x]], qans += (LL) V[C[x]] * W[cnt[C[x]]];
vis[x] ^= 1;
}
inline void workpath(int u, int v){
if(dep[u] < dep[v]) swap(u, v);
while(dep[u] > dep[v]){work(u); u = fa[u];}
while(u != v){work(u); u = fa[u]; work(v); v = fa[v];}
}
inline void change(int i){
int u = cha[i].u;
if(vis[u]){
work(u); swap(C[u], cha[i].val); work(u);
} else swap(C[u], cha[i].val);
}
int main(){
scanf("%d%d%d", &n, &m, &q); B = pow(n, 2.0 / 3);
for(Rint i = 1;i <= m;i ++) scanf("%d", V + i);
for(Rint i = 1;i <= n;i ++) scanf("%d", W + i);
for(Rint i = 1;i < n;i ++){
int a, b;
scanf("%d%d", &a, &b);
add(a, b); add(b, a);
}
dfs1(1);
if(!bnum) bnum = 1;
while(tp) bel[stk[tp --]] = bnum;
dfs2(1, 1);
for(Rint i = 1;i <= n;i ++) scanf("%d", C + i);
for(Rint i = 1;i <= q;i ++){
int opt, x, y;
scanf("%d%d%d", &opt, &x, &y);
if(opt == 0) cha[++ cnum] = (Change){x, y};
else ++ qnum, que[qnum] = (Query){x, y, qnum, cnum};
}
sort(que + 1, que + qnum + 1);
ql = qr = 1; qnow = 0;
for(Rint i = 1;i <= qnum;i ++){
int tl = que[i].u, tr = que[i].v;
workpath(ql, tl); workpath(qr, tr);
ql = tl; qr = tr;
while(qnow < que[i].tim) change(++ qnow);
while(qnow > que[i].tim) change(qnow --);
int LCA = lca(tl, tr);
work(LCA); ans[que[i].id] = qans; work(LCA);
}
for(Rint i = 1;i <= qnum;i ++) printf("%lld\n", ans[i]);
}