AtCoder Beginner Contest 298 Ex Sum of Min of Length
挺无脑的。是不是因为 unr 所以评分虚高啊/qd
考虑把 \(L_i \to R_i\) 的路径拎出来,那么路径中点(或中边)左边的点挂的子树到 \(L_i\) 更优,右边的点挂的子树到 \(R_i\) 更优。
差分一下:路径某一侧挂着的点恰好构成某个点 \(w\) 的子树,或者是它的补集,而补集的贡献可以写成「所有点的贡献 \(-\) \(w\) 子树的贡献」(所有点即根的子树)。于是可以转化成 \(O(q)\) 次询问,每次询问形如 \((u, v)\),表示求 \(v\) 子树中所有点到 \(u\) 的距离之和。
考虑离线,把 \((u, v)\) 的询问挂到点 \(u\) 上,然后做一遍 dfs。dfs 的同时(按 dfs 序)维护所有点到当前点的距离,那么从父亲走到儿子就相当于:儿子这棵子树内的距离全部 \(-1\),子树外的距离全部 \(+1\)。由于 \(v\) 的子树在 dfs 序上是一段连续区间 \([st_v, ed_v]\),处理挂在当前点的询问时做一次区间求和即可。区间加、区间求和,使用线段树维护。
总时间复杂度 \(O((n + q) \log n)\)。
code
// Problem: Ex - Sum of Min of Length
// Contest: AtCoder - TOYOTA MOTOR CORPORATION Programming Contest 2023#1 (AtCoder Beginner Contest 298)
// URL: https://atcoder.jp/contests/abc298/tasks/abc298_h
// Memory Limit: 1024 MB
// Time Limit: 3000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
// Competitive-programming shorthands used throughout the file.
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;

// Numeric / pair aliases.
using ll = long long;
using ull = unsigned long long;
using db = double;
using ldb = long double;
using pii = pair<ll, ll>;

// Array capacity (assumed to cover the problem's vertex/query limits, ~2e5)
// and the second dimension of the binary-lifting table.
constexpr int maxn = 200100;
constexpr int logn = 20;

// n: vertex count, m: query count.
int n, m;
// Per-vertex data from the first DFS: parent, subtree size, heavy child, depth.
int fa[maxn], sz[maxn], son[maxn], dep[maxn];
// Binary-lifting ancestor table: f[v][j] is the 2^j-th ancestor of v.
int f[maxn][logn];
// HLD chain head, DFS entry/exit indices, DFS counter, and index -> vertex map.
int top[maxn], st[maxn], ed[maxn], times, rnk[maxn];
// Visited marks for the second (HLD) DFS.
bool vis[maxn];
// Accumulated answer for each query id.
ll ans[maxn];
// Adjacency lists of the tree.
vector<int> G[maxn];
// One offline sub-query. When the traversal stands at the vertex this record is
// attached to, k * (sum of the stored values over DFS positions [l, r]) is added
// into ans[id].
struct node {
    int l;
    int r;
    int k;
    int id;
    node(int lo = 0, int hi = 0, int coef = 0, int qid = 0)
        : l(lo), r(hi), k(coef), id(qid) {}
};
// vc[u]: offline sub-queries evaluated by dfs3 when it visits vertex u.
vector<node> vc[maxn];
int dfs(int u, int f, int d) {
fa[u] = f;
sz[u] = 1;
dep[u] = d;
st[u] = ++times;
rnk[times] = u;
int maxson = -1;
for (int v : G[u]) {
if (v == f) {
continue;
}
sz[u] += dfs(v, u, d + 1);
if (sz[v] > maxson) {
son[u] = v;
maxson = sz[v];
}
}
ed[u] = times;
return sz[u];
}
// Second DFS of the heavy-light decomposition: top[u] becomes tp (the head of u's
// heavy chain). The heavy child extends the chain; every other unvisited neighbour
// (the parent is already marked in vis) starts a new chain headed by itself.
void dfs2(int u, int tp) {
    vis[u] = true;
    top[u] = tp;
    const int heavy = son[u];
    if (heavy == 0) {
        return;  // no heavy child was recorded, so u is a leaf
    }
    dfs2(heavy, tp);
    for (int v : G[u]) {
        if (vis[v]) {
            continue;
        }
        dfs2(v, v);
    }
}
// Lowest common ancestor via heavy-light decomposition. While x and y lie on
// different chains, move whichever has the deeper chain head to the parent of that
// head. Once both are on one chain, the shallower vertex is the LCA.
inline int qlca(int x, int y) {
    while (top[x] != top[y]) {
        int &deeper = (dep[top[x]] >= dep[top[y]]) ? x : y;
        deeper = fa[top[deeper]];
    }
    return dep[x] <= dep[y] ? x : y;
}
// Number of edges on the tree path between x and y.
inline int qdis(int x, int y) {
    const int anc = qlca(x, y);
    return (dep[x] - dep[anc]) + (dep[y] - dep[anc]);
}
// k-th ancestor of x using the binary-lifting table f. Only bits 0..18 of k are
// consulted, so this assumes 0 <= k < 2^19.
inline int jump(int x, int k) {
    int bit = 18;
    while (bit >= 0) {
        if ((k >> bit) & 1) {
            x = f[x][bit];
        }
        --bit;
    }
    return x;
}
// Lazy segment tree over DFS-order positions [1, n] supporting range add and
// range sum. In this solution position i holds the current distance from the
// traversal's current vertex to vertex rnk[i].
namespace SGT {
// tree[x]: sum over node x's range; tag[x]: per-position add not yet pushed to children.
ll tree[maxn << 2], tag[maxn << 2];

// Recomputes node x's sum from its two children.
inline void pushup(int x) {
    tree[x] = tree[x << 1] + tree[x << 1 | 1];
}

// Applies node x's pending add (x covers [l, r]) to both children and clears it.
inline void pushdown(int x, int l, int r) {
    if (!tag[x]) {
        return;
    }
    int mid = (l + r) >> 1;
    tree[x << 1] += tag[x] * (mid - l + 1);
    tree[x << 1 | 1] += tag[x] * (r - mid);
    tag[x << 1] += tag[x];
    tag[x << 1 | 1] += tag[x];
    tag[x] = 0;
}

// Builds node rt over [l, r]. Leaf l gets dep[rnk[l]] - 1, which is the distance
// from vertex 1 because solve() starts the first DFS at vertex 1 with depth 1.
void build(int rt, int l, int r) {
    if (l == r) {
        tree[rt] = dep[rnk[l]] - 1;
        return;
    }
    int mid = (l + r) >> 1;
    build(rt << 1, l, mid);
    build(rt << 1 | 1, mid + 1, r);
    pushup(rt);
}

// Adds x to every position in [ql, qr] (node rt covers [l, r]).
void update(int rt, int l, int r, int ql, int qr, int x) {
    if (ql <= l && r <= qr) {
        // Widen before multiplying: x * length in int can overflow for large
        // x or long ranges even though the stored sum is 64-bit.
        tree[rt] += static_cast<ll>(x) * (r - l + 1);
        tag[rt] += x;
        return;
    }
    pushdown(rt, l, r);
    int mid = (l + r) >> 1;
    if (ql <= mid) {
        update(rt << 1, l, mid, ql, qr, x);
    }
    if (qr > mid) {
        update(rt << 1 | 1, mid + 1, r, ql, qr, x);
    }
    pushup(rt);
}

// Returns the sum over positions [ql, qr] (node rt covers [l, r]).
ll query(int rt, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) {
        return tree[rt];
    }
    pushdown(rt, l, r);
    int mid = (l + r) >> 1;
    ll res = 0;
    if (ql <= mid) {
        res += query(rt << 1, l, mid, ql, qr);
    }
    if (qr > mid) {
        res += query(rt << 1 | 1, mid + 1, r, ql, qr);
    }
    return res;
}
}
void dfs3(int u, int fa) {
for (node p : vc[u]) {
ans[p.id] += SGT::query(1, 1, n, p.l, p.r) * p.k;
}
for (int v : G[u]) {
if (v == fa) {
continue;
}
SGT::update(1, 1, n, 1, n, 1);
SGT::update(1, 1, n, st[v], ed[v], -2);
dfs3(v, u);
SGT::update(1, 1, n, 1, n, -1);
SGT::update(1, 1, n, st[v], ed[v], 2);
}
}
void solve() {
scanf("%d", &n);
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(v);
G[v].pb(u);
}
scanf("%d", &m);
dfs(1, -1, 1);
dfs2(1, 1);
for (int i = 2; i <= n; ++i) {
f[i][0] = fa[i];
}
for (int j = 1; j <= 18; ++j) {
for (int i = 1; i <= n; ++i) {
f[i][j] = f[f[i][j - 1]][j - 1];
}
}
for (int i = 1, x, y; i <= m; ++i) {
scanf("%d%d", &x, &y);
if (x == y) {
vc[x].pb(1, n, 1, i);
continue;
}
if (dep[x] > dep[y]) {
swap(x, y);
}
int dis = qdis(x, y), lca = qlca(x, y);
if (lca == x) {
int u = jump(y, dis / 2);
vc[y].pb(st[u], ed[u], 1, i);
vc[x].pb(1, n, 1, i);
vc[x].pb(st[u], ed[u], -1, i);
} else {
vc[x].pb(1, n, 1, i);
int u = jump(y, dis - dis / 2 - 1);
vc[x].pb(st[u], ed[u], -1, i);
vc[y].pb(st[u], ed[u], 1, i);
}
}
SGT::build(1, 1, n);
dfs3(1, -1);
for (int i = 1; i <= m; ++i) {
printf("%lld\n", ans[i]);
}
}
// Entry point: runs solve() once per test case (a single case by default).
int main() {
    int cases = 1;
    // For multi-test input, read the count instead: scanf("%d", &cases);
    for (; cases > 0; --cases) {
        solve();
    }
    return 0;
}