CodeForces - 916E Jamie and Tree(lca,线段树,维护树换根后的子树信息)
题目大意
给你一棵以1为根的树,有三种操作:
1:把树的根改为v。
2:找到包含u和v的最小子树,给他们都加上x。
3:查询以v为根的子树的总和。
解题思路
首先先考虑第3个操作,固定1号点为根,记录一下当前的根是哪个点,询问时,如果询问的是新根,就是整棵子树的值。如果询问的点v在当前根的子树里,或者和当前的根没有父子关系,那么要查询的这个子树还是原来1号点为根的子树。如果询问的点v是当前根的祖先,那么由于从1号点换到了别的根,新根的祖先节点就变成了儿子节点,父子关系颠倒了。查询的其实是所有节点中除了以v点与当前根所在链上的直接儿子为根的子树以外的点。
第2个操作,先求出两个点与新根的lca,如果两个lca相同,就看lca(u,v)和新根的关系,如果不同,就取深度最大的那个lca(如果在子树内,lca为新根,否则在父子关系颠倒的子树上,所以取深度最大的)。然后就变成了判断一个点与新根的关系,与操作3类似。
代码
const int maxn = 2e5+10;
vector<int> e[maxn];
int n, m, val[maxn], dep[maxn], f[maxn][21];
int tim, idx[maxn], rev[maxn], son[maxn], sz[maxn];
void dfs1(int u, int p) {
sz[u] = 1;
for (auto v : e[u]) {
if (v==p) continue;
dep[v] = dep[u]+1;
f[v][0] = u;
for (int i = 1; i<21; ++i) f[v][i] = f[f[v][i-1]][i-1];
dfs1(v, u);
sz[u] += sz[v];
if (sz[v]>sz[son[u]]) son[u] = v;
}
}
int top[maxn];
void dfs2(int u, int t) {
top[u] = t; idx[u] = ++tim, rev[tim] = u;
if (!son[u]) return;
dfs2(son[u], t);
for (auto v : e[u]) {
if (v==son[u] || v==f[u][0]) continue;
dfs2(v, v);
}
}
int root = 1;
int lca(int u, int v) {
if (dep[u]<dep[v]) swap(u, v);
for (int i = 19; i>=0; --i)
if (dep[f[u][i]]>=dep[v]) u = f[u][i];
if (u==v) return u;
for (int i = 19; i>=0; --i)
if (f[u][i]!=f[v][i]) u = f[u][i], v = f[v][i];
return f[u][0];
}
int get_lca(int u, int v) {
int lu = lca(root, u);
int lv = lca(root, v);
if (lu==lv) return lca(u, v);
else return dep[lu]>dep[lv] ? lu:lv;
}
ll tr[maxn<<2], lz[maxn<<2];
inline void push_up(int rt) {
tr[rt] = tr[rt<<1]+tr[rt<<1|1];
}
inline void push_down(int rt, int len) {
if (lz[rt]) {
tr[rt<<1] += 1LL*(len+1)/2*lz[rt];
tr[rt<<1|1] += 1LL*(len/2)*lz[rt];
lz[rt<<1] += lz[rt];
lz[rt<<1|1] += lz[rt];
lz[rt] = 0;
}
}
void build(int rt, int l, int r) {
if (l==r) {
tr[rt] = val[rev[l]];
return;
}
int mid = (l+r)>>1;
build(rt<<1, l, mid);
build(rt<<1|1, mid+1, r);
push_up(rt);
}
void update(int rt, int l, int r, int L, int R, int V) {
if (l>=L && r<=R) {
tr[rt] += 1LL*(r-l+1)*V;
lz[rt] += V;
return;
}
int mid = (l+r)>>1;
push_down(rt, r-l+1);
if (L<=mid) update(rt<<1, l, mid, L, R, V);
if (R>mid) update(rt<<1|1, mid+1, r, L, R, V);
push_up(rt);
}
ll query(int rt, int l, int r, int L, int R) {
if (l>=L && r<=R) return tr[rt];
int mid = (l+r)>>1; ll sum = 0;
push_down(rt, r-l+1);
if (L<=mid) sum += query(rt<<1, l, mid, L, R);
if (R>mid) sum += query(rt<<1|1, mid+1, r, L, R);
return sum;
}
void change(int x, int V) {
if (x==root) update(1, 1, n, 1, n, V);
else if (lca(root, x)!=x) update(1, 1, n, idx[x], idx[x]+sz[x]-1, V);
else {
int y = root;
for (int i = 20; i>=0; --i)
if (dep[f[y][i]]>=dep[x]+1) y = f[y][i];
if (1<=idx[y]-1) update(1, 1, n, 1, idx[y]-1, V);
if (n>=idx[y]+sz[y]) update(1, 1, n, idx[y]+sz[y], n, V);
}
}
ll ask(int x) {
if (x==root) return tr[1];
if (lca(root, x)!=x) return query(1, 1, n, idx[x], idx[x]+sz[x]-1);
int y = root;
for (int i = 20; i>=0; --i)
if (dep[f[y][i]]>=dep[x]+1) y = f[y][i];
ll sum = 0;
if (1<=idx[y]-1) sum += query(1, 1, n, 1, idx[y]-1);
if (n>=idx[y]+sz[y]) sum += query(1, 1, n, idx[y]+sz[y], n);
return sum;
}
int main() {
IOS;
cin >> n >> m;
for (int i = 1; i<=n; ++i) cin >> val[i];
for (int i = 1, a, b; i<n; ++i) {
cin >> a >> b;
e[a].push_back(b);
e[b].push_back(a);
}
dep[1] = 1;
dfs1(1, 0);
dfs2(1, 1);
build(1, 1, n);
while(m--) {
int op; cin >> op;
if (op==1) cin >> root;
else if (op==2) {
int x, y, v; cin >> x >> y >> v;
change(get_lca(x, y), v);
}
else {
int x; cin >> x;
cout << ask(x) << endl;
}
}
return 0;
}