AtCoder Beginner Contest 298 Ex Sum of Min of Length

洛谷传送门

AtCoder 传送门

挺无脑的。是不是因为 unr 所以评分虚高啊/qd

考虑把 \(L_i \to R_i\) 的路径拎出来,那么挂在路径中点(或中边)左侧各点上的子树到 \(L_i\) 更优,挂在右侧各点上的子树到 \(R_i\) 更优(恰在中点上时两者距离相等,归到哪一侧都可以)。

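差分一下:不妨设 \(R_i\) 不比 \(L_i\) 浅,则到 \(R_i\) 严格更近的点恰好构成某个点 \(w\) 的子树,答案 = 全树到 \(L_i\) 的距离和 − \(w\) 子树到 \(L_i\) 的距离和 + \(w\) 子树到 \(R_i\) 的距离和。于是可以转化成 \(O(q)\) 次询问,每次询问形如 \((u, v)\),表示求 \(v\) 子树中所有点到 \(u\) 的距离之和(全树即根的子树)。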

考虑离线,把 \((u, v)\) 的询问挂到点 \(u\) 上,然后做一遍 dfs。dfs 的同时按 dfs 序维护所有点到当前点的距离,那么从父亲转移到儿子就相当于:儿子子树内的距离全部 \(-1\),子树外的距离全部 \(+1\)。由于 \(v\) 的子树在 dfs 序上是一段连续区间,处理挂在当前点上的查询时,做一次区间求和即可。区间加、区间求和使用线段树维护。

总时间复杂度 \(O((n + q) \log n)\)

code
// Problem: Ex - Sum of Min of Length
// Contest: AtCoder - TOYOTA MOTOR CORPORATION Programming Contest 2023#1 (AtCoder Beginner Contest 298)
// URL: https://atcoder.jp/contests/abc298/tasks/abc298_h
// Memory Limit: 1024 MB
// Time Limit: 3000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
// Contest-template shorthands. `pb` is emplace_back, so `vc[x].pb(a, b, c, d)`
// constructs a node in place from its four constructor arguments.
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
// NOTE(review): despite the name this is a pair of long longs; it is not used in the visible code.
typedef pair<ll, ll> pii;

// Array capacity; assumes n < maxn — TODO confirm against the problem constraints.
const int maxn = 200100;
// Second dimension of the binary-lifting table f; solve() fills and jump() reads levels 0..18 only.
const int logn = 20;

// n: vertex count, m: query count.
// fa / sz / son / dep: parent, subtree size, heavy child, depth (root has depth 1), all set by dfs().
// f[u][j]: 2^j-th ancestor of u, 0 when it does not exist (built in solve()).
int n, m, fa[maxn], sz[maxn], son[maxn], dep[maxn], f[maxn][logn];
// top: head of u's heavy chain (dfs2). st / ed: dfs-order entry / last index, so the
// subtree of u occupies positions [st[u], ed[u]]. times: dfs counter. rnk[i]: vertex at position i.
int top[maxn], st[maxn], ed[maxn], times, rnk[maxn];
// vis: visited flag used by dfs2.
bool vis[maxn];
// ans[i]: accumulated answer of the i-th query.
ll ans[maxn];
// G: adjacency lists of the tree.
vector<int> G[maxn];

struct node {
	// One pending range query, evaluated when dfs3 stands at the vertex whose list holds it:
	// ans[id] += k * (sum of distances from that vertex to dfs positions l..r).
	int l, r, k, id;
	node(int lo = 0, int hi = 0, int coef = 0, int qid = 0) {
		l = lo;
		r = hi;
		k = coef;
		id = qid;
	}
};

// vc[u]: pending range queries to evaluate when dfs3 reaches vertex u.
vector<node> vc[maxn];

int dfs(int u, int f, int d) {
	fa[u] = f;
	sz[u] = 1;
	dep[u] = d;
	st[u] = ++times;
	rnk[times] = u;
	int maxson = -1;
	for (int v : G[u]) {
		if (v == f) {
			continue;
		}
		sz[u] += dfs(v, u, d + 1);
		if (sz[v] > maxson) {
			son[u] = v;
			maxson = sz[v];
		}
	}
	ed[u] = times;
	return sz[u];
}

// Second DFS pass (heavy-light decomposition): top[u] becomes the head of u's chain.
// The heavy child continues the current chain; every other unvisited neighbour starts its own.
void dfs2(int u, int tp) {
	vis[u] = 1;
	top[u] = tp;
	if (son[u] == 0) {
		return;
	}
	dfs2(son[u], tp);
	for (int w : G[u]) {
		if (vis[w]) {
			continue;
		}
		dfs2(w, w);
	}
}

inline int qlca(int x, int y) {
	while (top[x] != top[y]) {
		if (dep[top[x]] < dep[top[y]]) {
			swap(x, y);
		}
		x = fa[top[x]];
	}
	if (dep[x] > dep[y]) {
		swap(x, y);
	}
	return x;
}

// Number of edges on the tree path between x and y.
inline int qdis(int x, int y) {
	const int l = qlca(x, y);
	return (dep[x] - dep[l]) + (dep[y] - dep[l]);
}

// k-th ancestor of x via the binary-lifting table f; only bits 0..18 of k are used.
// If k reaches past the root the walk falls into f's zero entries and the result is 0.
inline int jump(int x, int k) {
	for (int bit = 0; bit <= 18; ++bit) {
		if ((k >> bit) & 1) {
			x = f[x][bit];
		}
	}
	return x;
}

namespace SGT {
	// Lazy-propagation segment tree over dfs positions 1..n: range add, range sum.
	// tree[rt] is the sum over rt's range (already including rt's own pending tag);
	// tag[rt] is an addition that has not yet been pushed to rt's children.
	ll tree[maxn << 2], tag[maxn << 2];
	
	// Recompute node x's sum from its two children.
	inline void pushup(int x) {
		tree[x] = tree[x << 1] + tree[x << 1 | 1];
	}
	
	// Hand node x's pending addition (x covers [l, r]) down to both children.
	inline void pushdown(int x, int l, int r) {
		if (!tag[x]) {
			return;
		}
		int mid = (l + r) >> 1;
		tree[x << 1] += tag[x] * (mid - l + 1);
		tree[x << 1 | 1] += tag[x] * (r - mid);
		tag[x << 1] += tag[x];
		tag[x << 1 | 1] += tag[x];
		tag[x] = 0;
	}
	
	// Leaf i starts at dep[rnk[i]] - 1: the distance from the root (depth 1)
	// to the vertex at dfs position i.
	void build(int rt, int l, int r) {
		if (l == r) {
			tree[rt] = dep[rnk[l]] - 1;
			return;
		}
		int mid = (l + r) >> 1;
		build(rt << 1, l, mid);
		build(rt << 1 | 1, mid + 1, r);
		pushup(rt);
	}
	
	// Add x to every position in [ql, qr].
	// The length-weighted delta is computed in 64 bits: as an int product it would
	// overflow once |x| * (r - l + 1) exceeds INT_MAX.
	void update(int rt, int l, int r, int ql, int qr, int x) {
		if (ql <= l && r <= qr) {
			tree[rt] += static_cast<ll>(x) * (r - l + 1);
			tag[rt] += x;
			return;
		}
		pushdown(rt, l, r);
		int mid = (l + r) >> 1;
		if (ql <= mid) {
			update(rt << 1, l, mid, ql, qr, x);
		}
		if (qr > mid) {
			update(rt << 1 | 1, mid + 1, r, ql, qr, x);
		}
		pushup(rt);
	}
	
	// Sum of the values at positions [ql, qr].
	ll query(int rt, int l, int r, int ql, int qr) {
		if (ql <= l && r <= qr) {
			return tree[rt];
		}
		pushdown(rt, l, r);
		int mid = (l + r) >> 1;
		ll res = 0;
		if (ql <= mid) {
			res += query(rt << 1, l, mid, ql, qr);
		}
		if (qr > mid) {
			res += query(rt << 1 | 1, mid + 1, r, ql, qr);
		}
		return res;
	}
}

void dfs3(int u, int fa) {
	for (node p : vc[u]) {
		ans[p.id] += SGT::query(1, 1, n, p.l, p.r) * p.k;
	}
	for (int v : G[u]) {
		if (v == fa) {
			continue;
		}
		SGT::update(1, 1, n, 1, n, 1);
		SGT::update(1, 1, n, st[v], ed[v], -2);
		dfs3(v, u);
		SGT::update(1, 1, n, 1, n, -1);
		SGT::update(1, 1, n, st[v], ed[v], 2);
	}
}

void solve() {
	scanf("%d", &n);
	for (int i = 1, u, v; i < n; ++i) {
		scanf("%d%d", &u, &v);
		G[u].pb(v);
		G[v].pb(u);
	}
	scanf("%d", &m);
	dfs(1, -1, 1);
	dfs2(1, 1);
	for (int i = 2; i <= n; ++i) {
		f[i][0] = fa[i];
	}
	for (int j = 1; j <= 18; ++j) {
		for (int i = 1; i <= n; ++i) {
			f[i][j] = f[f[i][j - 1]][j - 1];
		}
	}
	for (int i = 1, x, y; i <= m; ++i) {
		scanf("%d%d", &x, &y);
		if (x == y) {
			vc[x].pb(1, n, 1, i);
			continue;
		}
		if (dep[x] > dep[y]) {
			swap(x, y);
		}
		int dis = qdis(x, y), lca = qlca(x, y);
		if (lca == x) {
			int u = jump(y, dis / 2);
			vc[y].pb(st[u], ed[u], 1, i);
			vc[x].pb(1, n, 1, i);
			vc[x].pb(st[u], ed[u], -1, i);
		} else {
			vc[x].pb(1, n, 1, i);
			int u = jump(y, dis - dis / 2 - 1);
			vc[x].pb(st[u], ed[u], -1, i);
			vc[y].pb(st[u], ed[u], 1, i);
		}
	}
	SGT::build(1, 1, n);
	dfs3(1, -1);
	for (int i = 1; i <= m; ++i) {
		printf("%lld\n", ans[i]);
	}
}

// Entry point: a single test case. Reading the case count is left commented out
// so that multi-test input can be enabled.
int main() {
	int cases = 1;
	// scanf("%d", &cases);
	for (int t = 0; t < cases; ++t) {
		solve();
	}
	return 0;
}

posted @ 2023-06-16 14:08  zltzlt  阅读(7)  评论(0)  编辑  收藏  举报