[2023四校联考3]sakuya

[2023四校联考3]sakuya

题意

给出一棵 \(n\) 个点的树,有 \(m\) 个特殊点 \(a\),求将 \(a\) 随机打乱后

\[\sum_{i=2}^m d(a_{i-1},a_i) \bmod 998244353 \]

的期望。有 \(q\) 次修改,每次将一个点连接的所有边权值增加。

思路

发现期望可以变为求和。

\(S\) 为所有情况的和,\(\frac{S}{m!}\) 就是答案。

如何求出 \(S\) 呢?

考虑每个 \(d(a_{i-1},a_i)\)\(S\) 的贡献。

发现只有当 \(a_{i-1},a_i\) 相邻时,\(d(a_{i-1},a_i)\)\(S\) 有贡献。

次数为 \(2!\times(m-1)!\),用捆绑法的思路把 \(a_{i-1},a_i\) 看成一个整体,再排列。

所以答案为

\[\frac{2 \times (m-1)!\sum_{i=1}^{m} \sum_{j=i+1}^{m} d(i,j)}{m!} \]

考虑如何求出 \(\sum_{i=1}^{m} \sum_{j=i+1}^{m} d(i,j)\),并考虑如何支持修改。

我们可以统计出所有边的经过次数,每个边的次数乘上边权求和就是答案。

而且这样也支持修改,因为树的形态没有改变,每条边经过的次数也是固定的。

做加法时乘上次数即可。

求解每条边经过的次数可以使用树上查分,记 \(c\) 为树上差分数组。

在树上进行 dfs 时。每到一个点,它的子树两两之间对次数都有贡献。

\(x\) 有两棵子树 \(v_1\)\(v_2\),对差分的贡献为:

\(s_{x}\)\(x\) 子树内特殊点的个数,\(v_1\) 子树内的特殊点的 \(c\) 加上 \(s_{v_2}\)\(v_2\) 子树内的特殊点的 \(s_{v_1}\)

\(c_x\) 减去 \(s_{v_1}+s_{v_2}\)。这几步操作相当于把两个子树的特殊点两两连边,而子树内部的情况可以递归处理。

如果 \(x\) 本身是特殊点,还要把子树内每个特殊点的 \(c\) 加一,\(c_x\) 减去 \(s_x\)

一个 \(x\) 若有多个子树,可逐一计算贡献,计算完后把两棵子树合并为一棵,再继续计算下一棵子树的贡献。

但这样的时间复杂度是:\(O(n^2)\),如何优化呢?

如果只记录特殊点的 dfn,则子树内特殊点的 dfn 是连续的,可以使用线段树为维护差分值。

时间复杂度:\(O(n\log n)\)

代码

#include <bits/stdc++.h>
#define int ll
#define inv(x) (qpow(x%mod,mod-2))
using namespace std;

using ll = long long;
const int N = 5e5 + 5;
const ll mod = 998244353;

int tot, ver[N << 1], nxt[N << 1], head[N], edge[N << 1];
int n, m, q, g[N], dfn[N], cnt, L[N], siz[N], ef[N];
bool G[N], in[N];
vector <pair <int, int>> E[N];
ll sum, c[N], disSum, d[N], ans, fac[N]; 

ll qpow(ll x, ll y) {
	ll res = 1;
	for (; y; y >>= 1, x = x * x % mod) 
		if (y & 1) res = res * x % mod;	
	return res;
}

struct segt {
	struct node {
		int l, r;
		ll ad, sum;
	} t[N << 2];
	
	#define ls (p << 1)
	#define rs (p << 1 |1)
	
	void build(int p, int l, int r) {
		t[p].l = l, t[p].r = r;
		if (l == r) return ;
		int mid = (l + r) >> 1;
		build(ls, l, mid);
		build(rs, mid + 1, r);
	} 
	
	void make(int p, ll v) {
		t[p].ad += v;
		t[p].sum += (t[p].r - t[p].l + 1) * v % mod;
		t[p].sum %= mod, t[p].ad %= mod;
	}
	
	void push_down(int p) {
		if (t[p].ad) {
			make(ls, t[p].ad);
			make(rs, t[p].ad);
			t[p].ad = 0;
		}
	}
	
	void add(int p, int l, int r, ll v) {
		if (l <= t[p].l && t[p].r <= r) {
			make(p, v);
			return ;
		}		
		push_down(p);
		if (t[ls].r >= l) add(ls, l, r, v);
		if (t[rs].l <= r) add(rs, l, r, v);
		t[p].sum = t[ls].sum + t[rs].sum; 
		t[p].sum %= mod;
	}
	
	ll query(int p, int id) {
		if (t[p].l == t[p].r) return t[p].sum;
		push_down(p);
		if (id <= t[ls].r) return query(ls, id);
		else return query(rs, id);
	}
} T; 

void add(int x, int y, int z) {
	ver[++ tot] = y;
	nxt[tot] = head[x];
	head[x] = tot;
	edge[tot] = z;
}

bool dfs1(int x, int fa) {
	bool res = 0;
	if (G[x]) siz[x] = 1;
	for (int i = head[x], y; i; i = nxt[i]) {
		y = ver[i];
		if (y == fa) {
			continue;
		}
		res |= dfs1(y, x);
		siz[x] += siz[y];
	}
	if (res) {
		in[x] = 1;
		for (int i = head[x], y; i; i = nxt[i]) {
			y = ver[i];
			if (y == fa) continue;
			if (!siz[y]) continue;
			E[x].push_back({y, edge[i]});
			E[y].push_back({x, edge[i]});
		}
	}
	if (G[x]) in[x] = 1;
	return res | G[x];
}

void dfs2(int x, int fa) {
	L[x] = 1e9;
	if (G[x]) {
		dfn[x] = ++ cnt;
		L[x] = dfn[x];	
	}
	for (auto e : E[x]) {
		int y = e.first, z = e.second;
		if (y == fa) continue;
		sum += z; 
		dfs2(y, x);
		L[x] = min(L[x], L[y]);
	}
} 

void dfs3(int x, int fa) {
	int nowSize = 0, nowL = 1e9;
	for (auto e : E[x]) {
		int y = e.first;
		if (y == fa) continue;
		dfs3(y, x);
		if (!nowSize) {
			nowSize += siz[y];
			nowL = min(nowL, L[y]);
			continue;
		}
		T.add(1, nowL, nowL + nowSize - 1, siz[y]);
		T.add(1, L[y], L[y] + siz[y] - 1, nowSize);
		c[x] += -siz[y] * nowSize - nowSize * siz[y];
		nowSize += siz[y];
		nowL = min(nowL, L[y]);
	}
	if (G[x]) {
		for (auto e : E[x]) {
			int y = e.first;
			if (y == fa) continue;
			T.add(1, L[y], L[y] + siz[y] - 1, 1);
			c[x] -= siz[y];
		}
	}
	
} 
 
void dfs4(int x, int fa) {
	if (dfn[x]) {
		c[x] += T.query(1, dfn[x]);
	}
	for (auto e : E[x]) {
		int y = e.first, z = e.second;
		if (y == fa)  {
			ef[x] = z;
			continue;	
		}	
		dfs4(y, x);
		c[x] += c[y];
		c[x] %= mod;
	}
	disSum = (disSum + c[x] * ef[x]) % mod;
}

void dfs5(int x, int fa) {
	d[x] = c[x];
	for (auto e : E[x]) {
		int y = e.first;
		if (y == fa)  {
			continue;	
		}	
		dfs5(y, x);
		d[x] += c[y];
		d[x] %= mod;
	}
}

signed main() {
	freopen("sakuya.in", "r", stdin);
	freopen("sakuya.out", "w", stdout);
	cin >> n >> m;
	for (int i = 1, x, y, z; i < n; i ++) {
		cin >> x >> y >> z;
		add(x, y, z);
		add(y, x, z);
	}
	fac[0] = 1;
	for (int i = 1; i <= m; i ++) {
		cin >> g[i];
		G[g[i]] = 1;
		fac[i] = fac[i - 1] * i % mod;
	} 	
	dfs1(1, 0); dfs2(1, 0);
	T.build(1, 1, n);
	dfs3(1, 0); dfs4(1, 0); dfs5(1, 0);
	cin >> q;
	while (q --) {
		int x, k;
		cin >> x >> k;
		disSum += k * d[x] % mod, disSum %= mod; 
		ans = 2 * fac[m - 1] % mod * disSum % mod * inv(fac[m]) % mod;
		ans = (ans % mod + mod) % mod;
		cout << ans << "\n";
	}
	return 0;
}
posted @ 2024-09-27 08:49  maniubi  阅读(7)  评论(0编辑  收藏  举报