Luogu P2590 树的统计(树链剖分+线段树)

题意

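给定一棵 n 个节点（编号 1…n）的树，每个节点有一个整数权值。支持三种操作：`CHANGE u t` 将节点 u 的权值改为 t；`QMAX u v` 询问 u 到 v 路径上（含两端点）节点权值的最大值；`QSUM u v` 询问 u 到 v 路径上（含两端点）节点权值之和。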

题解

重链剖分模板题：按“重儿子优先”的 DFS 序给节点编号后，每条重链在 DFS 序上是一段连续区间，任意树上路径都能拆成 O(log n) 段这样的区间，因此用线段树维护区间和与区间最大值即可，每次询问复杂度 O(log² n)。

#include <cstdio>
#include <cstring>
#include <algorithm>
using std::max;
using std::swap;

// N bounds the node count (with slack); -Inf is used as the identity value
// of the max-queries further down.
const int N = 3e4 + 10, Inf = 1e9 + 7;
// n: node count, q: query count, c[u]: weight of node u,
// x, y: operands of the query currently being processed (read in main).
int n, q, c[N], x, y;
// Filled by dfs1: fa[u] parent, dep[u] depth (dep[0] == 0, so the starting
// node gets depth 1), son[u] heavy child (0 if none), siz[u] subtree size.
int fa[N], dep[N], son[N], siz[N];
// Filled by dfs2: top[u] head of u's heavy chain, dfn[u] DFS-order index
// of u, w[i] weight of the node whose DFS-order index is i.
// dfsClock is the running DFS-order counter. It was previously called `time`,
// which collides with ::time() from <ctime>/<time.h> ("redeclared as
// different kind of entity") whenever that header gets included, directly
// or transitively through another standard header.
int top[N], w[N], dfn[N], dfsClock;
// Edge lists: from[u] is the index of u's most recently added edge (0 = none),
// to[e] is that edge's endpoint, nxt[e] the next edge of the same node.
// Every tree edge is inserted in both directions, hence the N << 1 capacity.
int cnt, from[N], to[N << 1], nxt[N << 1];
// Segment tree over DFS order: per-node maximum and sum.
int maxv[N << 2], sumv[N << 2];

/// Appends the directed edge u -> v to the front of u's edge list.
inline void addEdge(int u, int v) {
    to[++cnt] = v;
    nxt[cnt] = from[u];
    from[u] = cnt;
}

/// First decomposition pass over the subtree rooted at u.
/// Sets fa for u's children and dep, siz, son for every node in the subtree.
/// fa[u] must already be set (it is 0 for the root, since globals start at 0).
void dfs1(int u) {
    siz[u] = 1;
    dep[u] = dep[fa[u]] + 1;
    for (int i = from[u]; i; i = nxt[i]) {
        int v = to[i];
        if (v == fa[u]) continue;
        fa[v] = u;
        dfs1(v);
        siz[u] += siz[v];
        // son[u] starts at 0 and siz[0] is 0, so the first child always
        // becomes the provisional heavy child.
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
}

/// Second decomposition pass: numbers nodes in DFS order, visiting the heavy
/// child first so each heavy chain occupies a contiguous range of dfn values.
/// t is the head of the chain that u belongs to.
void dfs2(int u, int t) {
    top[u] = t;
    dfn[u] = ++dfsClock;
    w[dfsClock] = c[u];        // lay weights out in DFS order for the segment tree
    if (!son[u]) return;       // no heavy child means no children at all
    dfs2(son[u], t);           // the heavy child extends the current chain
    for (int i = from[u]; i; i = nxt[i]) {
        int v = to[i];
        if (v != fa[u] && v != son[u])
            dfs2(v, v);        // every light child starts a new chain
    }
}

void pushup (int o, int lc, int rc) {
    sumv[o] = sumv[lc] + sumv[rc];
    maxv[o] = max(maxv[lc], maxv[rc]);
}
/// Builds segment-tree node o over positions [l, r] of w.
/// The defaults build the whole tree over [1, n].
void build(int o = 1, int l = 1, int r = n) {
    if (l != r) {
        const int lc = o << 1, rc = lc | 1, mid = (l + r) >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        pushup(o, lc, rc);
        return;
    }
    // Leaf: max and sum are both the single weight.
    maxv[o] = sumv[o] = w[l];
}
/// Point assignment: sets position p (a DFS-order index) to k, then refreshes
/// every ancestor on the way back up. The defaults start at the root [1, n].
void modify(int p, int k, int o = 1, int l = 1, int r = n) {
    if (l == r && p == l) {
        maxv[o] = sumv[o] = k;
        return;
    }
    const int lc = o << 1, rc = lc | 1, mid = (l + r) >> 1;
    if (p > mid)
        modify(p, k, rc, mid + 1, r);
    else
        modify(p, k, lc, l, mid);
    pushup(o, lc, rc);
}
/// Maximum over positions [ql, qr] within node o's range [l, r].
/// If neither child is visited, the result is -Inf.
int quemax(int ql, int qr, int o = 1, int l = 1, int r = n) {
    if (ql <= l && qr >= r) return maxv[o];   // node fully covered
    const int mid = (l + r) >> 1, lc = o << 1, rc = lc | 1;
    int best = -Inf;
    if (ql <= mid) best = quemax(ql, qr, lc, l, mid);
    if (mid < qr) {
        const int right = quemax(ql, qr, rc, mid + 1, r);
        if (best < right) best = right;
    }
    return best;
}
/// Sum over positions [ql, qr] within node o's range [l, r].
/// If neither child is visited, the result is 0.
int quesum(int ql, int qr, int o = 1, int l = 1, int r = n) {
    if (ql <= l && qr >= r) return sumv[o];   // node fully covered
    const int mid = (l + r) >> 1, lc = o << 1, rc = lc | 1;
    int total = 0;
    if (ql <= mid) total += quesum(ql, qr, lc, l, mid);
    if (mid < qr) total += quesum(ql, qr, rc, mid + 1, r);
    return total;
}

int quem(int x, int y) {
    int fx = top[x], fy = top[y], ret = -Inf;
    while(fx != fy) {
	if(dep[fx] >= dep[fy])
	    ret = max(ret, quemax(dfn[fx], dfn[x])), x = fa[fx], fx = top[x];
	else
	    ret = max(ret, quemax(dfn[fy], dfn[y])), y = fa[fy], fy = top[y];
    }
    if(dfn[x] > dfn[y]) swap(x, y);
    return max(ret, quemax(dfn[x], dfn[y]));
}
int ques(int x, int y) {
    int fx = top[x], fy = top[y], ret = 0;
    while(fx != fy) {
	if(dep[fx] >= dep[fy])
	    ret += quesum(dfn[fx], dfn[x]), x = fa[fx], fx = top[x];
	else
	    ret += quesum(dfn[fy], dfn[y]), y = fa[fy], fy = top[y];
    }
    if(dfn[x] > dfn[y]) swap(x, y);
    return ret + quesum(dfn[x], dfn[y]);
}


/// Reads n, then n - 1 edges, then n node weights, and builds the
/// decomposition and segment tree. It then processes q operations:
///   QMAX u v -> print the path maximum, QSUM u v -> print the path sum,
///   any other keyword (CHANGE u t) -> set node u's weight to t.
/// Returns 1 if the tree description is truncated or malformed.
/// Stops early if the operation list ends before q operations have been read.
int main () {
    if (scanf("%d", &n) != 1) return 1;
    for (int i = 1, u, v; i < n; ++i) {
        if (scanf("%d%d", &u, &v) != 2) return 1;
        addEdge(u, v), addEdge(v, u);
    }
    for (int i = 1; i <= n; ++i)
        if (scanf("%d", &c[i]) != 1) return 1;
    dfs1(1), dfs2(1, 1);
    build();
    if (scanf("%d", &q) != 1) return 1;
    char opt[10];
    while (q--) {
        // %9s caps the write at 9 chars + NUL so opt cannot overflow;
        // %s already skips leading whitespace, including newlines.
        if (scanf("%9s %d %d", opt, &x, &y) != 3) break;
        if (opt[0] == 'Q') {
            if (opt[1] == 'M') printf("%d\n", quem(x, y));
            else printf("%d\n", ques(x, y));
        } else {
            modify(dfn[x], y);   // CHANGE: position of node x in DFS order
        }
    }
    return 0;
}
posted @ 2018-10-21 18:24  water_mi  阅读(192)  评论(0)  编辑  收藏  举报