P3401 洛谷树 题解

P3401

考虑先将路径权值进行转化,因为很难对路径直接进行统计。考虑如何表示出这条路径的权值。记 \(s_i = \oplus_{j \in \text{path}(1, i)} w_j\),其中 \(\text{path}(i, j)\) 表示 \(i\)\(j\) 的路径上的边集。则 \(u\to v\) 的路径的权值为 \(s_u\oplus s_v\)

现在转变为了求两点路径上任选两点的权值和,以及区间异或。区间异或是因为单次修改一条边的边权会导致其子树内所有点的 \(s\) 值改变,即在 \(u\) 子树内的点都会异或上 \(x\oplus s_u\)\(x\) 指要修改为的值。

不好直接维护,但是看到异或,可以考虑拆位,这样统计和修改都会很简单。统计就是对每一位找到路径中有多少个点的 \(s\) 值的第 \(i\) 位为 \(1\),记为 \(c\),若路径上共有 \(p\) 个点,则答案为 \(2^i\times c\times(p - c)\)。然后修改操作变为区间取反,直接上线段树即可。

时间复杂度:\(\mathcal{O}(n\log^2n\log V)\)
空间复杂度:\(\mathcal{O}(n\log V)\)

代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector<int>
#define eb emplace_back
#define pii pair<int, ll>
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
bool Mbe;
mt19937_64 rng(35);
constexpr int N = 3e5 + 10;
int n, m, dfc;
int s[N], vn[N];
int sz[N], dep[N], hv[N], fa[N], top[N], dfn[N], rdfn[N];
int head[N], cnt_e;
struct edge {
	int to, w, nxt;
} e[N << 1];
void adde(int u, int v, int w) {
	++cnt_e, e[cnt_e].to = v, e[cnt_e].w = w, e[cnt_e].nxt = head[u], head[u] = cnt_e;
}
void dfs1(int u, int ff) {
	fa[u] = ff, dep[u] = dep[ff] + 1, sz[u] = 1;
	for(int i = head[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if(v == ff) continue;
		s[v] = e[i].w ^ s[u];
		vn[v] = e[i].w;
		dfs1(v, u);
		sz[u] += sz[v];
		if(sz[v] > sz[hv[u]]) hv[u] = v;
	}
}
void dfs2(int u, int f) {
	top[u] = f, rdfn[dfn[u] = ++dfc] = u;
	if(!hv[u]) return;
	dfs2(hv[u], f);
	for(int i = head[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if(v == hv[u] || v == fa[u]) continue;
		dfs2(v, v);
	}
}
vi cnt(10), tmp(10);
int val[N << 2][10], laz[N << 2][10];
void tag(int x, int y, int L, int R) {
	val[x][y] = (R - L + 1) - val[x][y];
	laz[x][y] ^= 1;
}
void down(int x, int y, int L, int R) {
	if(laz[x][y]) {
		int m = (L + R) >> 1;
		tag(x << 1, y, L, m);
		tag(x << 1 | 1, y, m + 1, R);
		laz[x][y] = 0;
	}
}
void build(int x, int L, int R) {
	if(L == R) {
		for(int i = 0; i < 10; ++i) val[x][i] = s[rdfn[L]] >> i & 1;
		return;
	}
	int m = (L + R) >> 1;
	build(x << 1, L, m);
	build(x << 1 | 1, m + 1, R);
	for(int i = 0; i < 10; ++i) val[x][i] = val[x << 1][i] + val[x << 1 | 1][i];
}
void modify(int x, int L, int R, int l, int r, int v) {
	if(l <= L && R <= r) {
		for(int i = 0; i < 10; ++i)
			if(v >> i & 1)
				tag(x, i, L, R);
		return;
	}
	for(int i = 0; i < 10; ++i) down(x, i, L, R);
	int m = (L + R) >> 1;
	if(l <= m) modify(x << 1, L, m, l, r, v);
	if(r > m) modify(x << 1 | 1, m + 1, R, l, r, v);
	for(int i = 0; i < 10; ++i)
		val[x][i] = val[x << 1][i] + val[x << 1 | 1][i];
}
void query(int x, int L, int R, int l, int r) {
	if(l <= L && R <= r) {
		for(int i = 0; i < 10; ++i) cnt[i] += val[x][i];
		return;
	}
	for(int i = 0; i < 10; ++i) down(x, i, L, R);
	int m = (L + R) >> 1;
	if(l <= m) query(x << 1, L, m, l, r);
	if(r > m) query(x << 1 | 1, m + 1, R, l, r);
}
ll qry(int u, int v) {
	int au = u, av = v;
	for(int i = 0; i < 10; ++i) cnt[i] = 0;
	while(top[u] ^ top[v]) {
		if(dep[top[u]] < dep[top[v]]) swap(u, v);
		query(1, 1, n, dfn[top[u]], dfn[u]);
		u = fa[top[u]];
	}
	if(dfn[u] < dfn[v]) swap(u, v);
	query(1, 1, n, dfn[v], dfn[u]);
	ll ans = 0;
	for(int i = 0; i < 10; ++i) {
		int c1 = cnt[i], c2 = dep[au] + dep[av] - 2 * dep[v] - cnt[i] + 1;
		ans += (1ll << i) * c1 * c2;
	}
	return ans;
}
bool Med;
int main() {
	fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
//	freopen("P3401_1.in", "r", stdin);
//	freopen("data.out", "w", stdout);
	ios :: sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	cin >> n >> m;
	for(int i = 1; i < n; ++i) {
		int u, v, w;
		cin >> u >> v >> w;
		adde(u, v, w);
		adde(v, u, w);
	}
	dfs1(1, 0);
	dfs2(1, 1);
	build(1, 1, n);
	for(int i = 1; i <= m; ++i) {
		int opt, x, y, z;
		cin >> opt >> x >> y;
		if(opt == 1) {
			cout << qry(x, y) << "\n";	
		} else {
			cin >> z;
			if(dep[x] < dep[y]) swap(x, y);
			modify(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, vn[x] ^ z);
			vn[x] = z;
		}
	}
	cerr << TIME << "ms\n";
	return 0;
}
posted @ 2023-11-21 16:41  Pengzt  阅读(9)  评论(0编辑  收藏  举报