【数据结构】重链剖分
维护点权信息,记得涉及线段树的操作都要套一个tid变换成线段树上的坐标。
struct TreeChain {
static const int MAXN = 100000 + 10;
struct Edge {
struct EdgeNode {
int v, nxt;
} en[MAXN << 1];
int h[MAXN], top;
void Init(int n) {
top = 0;
memset(h, 0, sizeof(h[0]) * (n + 1));
}
void Add(int u, int v) {
en[++top] = {v, h[u]}, h[u] = top;
en[++top] = {u, h[v]}, h[v] = top;
}
} edge;
int dep[MAXN], siz[MAXN], mch[MAXN], pat[MAXN];
int top[MAXN], tid[MAXN], dit[MAXN], cnt;
int a[MAXN];
void Init(int r) {
edge.Init(n);
for(int i = 1; i <= n; ++i)
scanf("%d", &a[i]);
for(int i = 1, u, v; i <= n - 1; ++i) {
scanf("%d%d", &u, &v);
edge.Add(u, v);
}
cnt = 0;
dfs1(r, 0);
dfs2(r, r);
st.Build(1, 1, n);
for(int i = 1; i <= n; ++i)
st.Update(1, 1, n, tid[i], tid[i], a[i]);
}
void dfs1(int u, int p) {
dep[u] = dep[p] + 1, siz[u] = 1, mch[u] = 0, pat[u] = p;
for(int i = edge.h[u]; i; i = edge.en[i].nxt) {
int v = edge.en[i].v;
if(v == p)
continue;
dfs1(v, u);
siz[u] += siz[v];
if(mch[u] == 0 || siz[v] > siz[mch[u]])
mch[u] = v;
}
}
void dfs2(int u, int t) {
top[u] = t, tid[u] = ++cnt, dit[cnt] = u;
if(mch[u] != 0)
dfs2(mch[u], t);
for(int i = edge.h[u]; i; i = edge.en[i].nxt) {
int v = edge.en[i].v;
if(v == pat[u] || v == mch[u])
continue;
dfs2(v, v);
}
}
int lca(int u, int v) {
for(int tu = top[u], tv = top[v]; tu != tv; u = pat[tu], tu = top[u]) {
if(dep[tu] < dep[tv])
swap(u, v), swap(tu, tv);
}
return (dep[u] <= dep[v]) ? u : v;
}
ll QueryChain(int u, int v) {
ll res = 0;
for(int tu = top[u], tv = top[v]; tu != tv; u = pat[tu], tu = top[u]) {
if(dep[tu] < dep[tv])
swap(u, v), swap(tu, tv);
res += st.Query(1, 1, n, tid[tu], tid[u]);
}
if(tid[u] > tid[v])
swap(u, v);
res += st.Query(1, 1, n, tid[u], tid[v]);
return res % mod;
}
ll QuerySubtree(int u) {
return st.Query(1, 1, n, tid[u], tid[u] + siz[u] - 1);
}
void UpdateChain(int u, int v, int val) {
for(int tu = top[u], tv = top[v]; tu != tv; u = pat[tu], tu = top[u]) {
if(dep[tu] < dep[tv])
swap(u, v), swap(tu, tv);
st.Update(1, 1, n, tid[tu], tid[u], val);
}
if(tid[u] > tid[v])
swap(u, v);
st.Update(1, 1, n, tid[u], tid[v], val);
}
void UpdateSubtree(int u, int val) {
st.Update(1, 1, n, tid[u], tid[u] + siz[u] - 1, val);
}
} tc;
维护边权信息
定根之后,每条边和深度更大的端点构成双射,用这个点代表这条边,在更新的时候注意边界问题。
ll QueryChain(int u, int v) {
ll res = 0;
for(int tu = top[u], tv = top[v]; tu != tv; u = pat[tu], tu = top[u]) {
if(dep[tu] < dep[tv])
swap(u, v), swap(tu, tv);
res += st.Query(1, 1, n, tid[tu], tid[u]);
}
if(tid[u] == tid[v])
return;
if(tid[u] > tid[v])
swap(u, v);
res += st.Query(1, 1, n, tid[u] + 1, tid[v]);
return res % mod;
}
在未到达同一条重链时,查询为tid[tu],tid[u],表示从u节点一直计算到轻链的顶部,并把轻重链切换的边也算上。
在到达同一条重链后,注意深度小的那个点不应该更新,这里做一个越界判断。
树状数组单点修改版本:
struct BinaryIndexTree {
int n;
int sm[MAXN];
void Add(int x, int v) {
for(int i = x; i <= n; i += i & (-i))
sm[i] += v;
}
int Sum(int x) {
int res = 0;
for(int i = x; i; i -= i & (-i))
res += sm[i];
return res;
}
int QuerySum(int x, int y) {
return Sum(y) - Sum(x - 1);
}
void Init(int _n) {
n = _n;
memset(sm, 0, sizeof(sm[0]) * (n + 1));
}
} bit;
struct TreeChain {
struct Edge {
struct EdgeNode {
int v, nxt;
} e[MAXN << 1];
int h[MAXN], top;
void Init(int n) {
top = 0;
memset(h, 0, sizeof(h[0]) * (n + 1));
}
void Add(int u, int v) {
e[++top] = {v, h[u]}, h[u] = top;
e[++top] = {u, h[v]}, h[v] = top;
}
} edge;
int dep[MAXN], siz[MAXN], mch[MAXN], pa[MAXN];
int top[MAXN], tid[MAXN], dit[MAXN], cnt;
void Init1(int n) {
edge.Init(n);
}
void Init2(int n, int r) {
cnt = 0;
dfs1(r, 0);
dfs2(r, r);
bit.Init(n);
}
void dfs1(int u, int p) {
dep[u] = dep[p] + 1, siz[u] = 1, mch[u] = 0, pa[u] = p;
for(int i = edge.h[u]; i; i = edge.e[i].nxt) {
int v = edge.e[i].v;
if(v == p)
continue;
dfs1(v, u);
siz[u] += siz[v];
if(mch[u] == 0 || siz[v] > siz[mch[u]])
mch[u] = v;
}
}
void dfs2(int u, int t) {
top[u] = t, tid[u] = ++cnt, dit[cnt] = u;
if(mch[u] != 0)
dfs2(mch[u], t);
for(int i = edge.h[u]; i; i = edge.e[i].nxt) {
int v = edge.e[i].v;
if(v == pa[u] || v == mch[u])
continue;
dfs2(v, v);
}
}
int BinarySearch(int u) {
// 找u节点的祖先(含u本身)中第一个非0节点
while(u > 1) {
if(bit.QuerySum(tid[top[u]], tid[u]) == 0)
u = pa[top[u]];
else
break;
}
if(u == 0 || bit.QuerySum(tid[top[u]], tid[u]) == 0)
return 0;
int L = tid[top[u]], R = tid[u];
while(L < R) {
int M = (L + R + 1) >> 1;
if(bit.QuerySum(M, R) == 0)
R = M - 1;
else
L = M;
}
return dit[L];
}
int QueryVertex(int u) {
return bit.QuerySum(tid[u], tid[u]);
}
int QueryChain(int u, int v) {
int res = 0;
for(int tu = top[u], tv = top[v]; tu != tv; u = pa[tu], tu = top[u]) {
if(dep[tu] < dep[tv])
swap(u, v), swap(tu, tv);
res += bit.QuerySum(tid[tu], tid[u]);
}
if(tid[u] > tid[v])
swap(u, v);
res += bit.QuerySum(tid[u], tid[v]);
return res;
}
int QuerySubtree(int u) {
return bit.QuerySum(tid[u], tid[u] + siz[u] - 1);
}
void UpdateVertex(int u, int val) {
bit.Add(tid[u], val);
}
} tc;