[2023四校联考3]sakuya
[2023四校联考3]sakuya
题意
给出一棵 \(n\) 个点的树,有 \(m\) 个特殊点 \(a\),求将 \(a\) 随机打乱后
的期望。有 \(q\) 次修改,每次将一个点连接的所有边权值增加。
思路
发现期望可以变为求和。
记 \(S\) 为所有情况的和,\(\frac{S}{m!}\) 就是答案。
如何求出 \(S\) 呢?
考虑每个 \(d(a_{i-1},a_i)\) 对 \(S\) 的贡献。
发现只有当 \(a_{i-1},a_i\) 相邻时,\(d(a_{i-1},a_i)\) 对 \(S\) 有贡献。
次数为 \(2!\times(m-1)!\),用捆绑法的思路把 \(a_{i-1},a_i\) 看成一个整体,再排列。
所以答案为
考虑如何求出 \(\sum_{i=1}^{m} \sum_{j=i+1}^{m} d(i,j)\),并考虑如何支持修改。
我们可以统计出所有边的经过次数,每个边的次数乘上边权求和就是答案。
而且这样也支持修改,因为树的形态没有改变,每条边经过的次数也是固定的。
做加法时乘上次数即可。
求解每条边经过的次数可以使用树上查分,记 \(c\) 为树上差分数组。
在树上进行 dfs 时。每到一个点,它的子树两两之间对次数都有贡献。
若 \(x\) 有两棵子树 \(v_1\) 和 \(v_2\),对差分的贡献为:
记 \(s_{x}\) 为 \(x\) 子树内特殊点的个数,\(v_1\) 子树内的特殊点的 \(c\) 加上 \(s_{v_2}\),\(v_2\) 子树内的特殊点的 \(s_{v_1}\)。
\(c_x\) 减去 \(s_{v_1}+s_{v_2}\)。这几步操作相当于把两个子树的特殊点两两连边,而子树内部的情况可以递归处理。
如果 \(x\) 本身是特殊点,还要把子树内每个特殊点的 \(c\) 加一,\(c_x\) 减去 \(s_x\)。
一个 \(x\) 若有多个子树,可逐一计算贡献,计算完后把两棵子树合并为一棵,再继续计算下一棵子树的贡献。
但这样的时间复杂度是:\(O(n^2)\),如何优化呢?
如果只记录特殊点的 dfn,则子树内特殊点的 dfn 是连续的,可以使用线段树为维护差分值。
时间复杂度:\(O(n\log n)\)。
代码
#include <bits/stdc++.h>
#define int ll
#define inv(x) (qpow(x%mod,mod-2))
using namespace std;
using ll = long long;
const int N = 5e5 + 5;
const ll mod = 998244353;
int tot, ver[N << 1], nxt[N << 1], head[N], edge[N << 1];
int n, m, q, g[N], dfn[N], cnt, L[N], siz[N], ef[N];
bool G[N], in[N];
vector <pair <int, int>> E[N];
ll sum, c[N], disSum, d[N], ans, fac[N];
ll qpow(ll x, ll y) {
ll res = 1;
for (; y; y >>= 1, x = x * x % mod)
if (y & 1) res = res * x % mod;
return res;
}
struct segt {
struct node {
int l, r;
ll ad, sum;
} t[N << 2];
#define ls (p << 1)
#define rs (p << 1 |1)
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r;
if (l == r) return ;
int mid = (l + r) >> 1;
build(ls, l, mid);
build(rs, mid + 1, r);
}
void make(int p, ll v) {
t[p].ad += v;
t[p].sum += (t[p].r - t[p].l + 1) * v % mod;
t[p].sum %= mod, t[p].ad %= mod;
}
void push_down(int p) {
if (t[p].ad) {
make(ls, t[p].ad);
make(rs, t[p].ad);
t[p].ad = 0;
}
}
void add(int p, int l, int r, ll v) {
if (l <= t[p].l && t[p].r <= r) {
make(p, v);
return ;
}
push_down(p);
if (t[ls].r >= l) add(ls, l, r, v);
if (t[rs].l <= r) add(rs, l, r, v);
t[p].sum = t[ls].sum + t[rs].sum;
t[p].sum %= mod;
}
ll query(int p, int id) {
if (t[p].l == t[p].r) return t[p].sum;
push_down(p);
if (id <= t[ls].r) return query(ls, id);
else return query(rs, id);
}
} T;
void add(int x, int y, int z) {
ver[++ tot] = y;
nxt[tot] = head[x];
head[x] = tot;
edge[tot] = z;
}
bool dfs1(int x, int fa) {
bool res = 0;
if (G[x]) siz[x] = 1;
for (int i = head[x], y; i; i = nxt[i]) {
y = ver[i];
if (y == fa) {
continue;
}
res |= dfs1(y, x);
siz[x] += siz[y];
}
if (res) {
in[x] = 1;
for (int i = head[x], y; i; i = nxt[i]) {
y = ver[i];
if (y == fa) continue;
if (!siz[y]) continue;
E[x].push_back({y, edge[i]});
E[y].push_back({x, edge[i]});
}
}
if (G[x]) in[x] = 1;
return res | G[x];
}
void dfs2(int x, int fa) {
L[x] = 1e9;
if (G[x]) {
dfn[x] = ++ cnt;
L[x] = dfn[x];
}
for (auto e : E[x]) {
int y = e.first, z = e.second;
if (y == fa) continue;
sum += z;
dfs2(y, x);
L[x] = min(L[x], L[y]);
}
}
void dfs3(int x, int fa) {
int nowSize = 0, nowL = 1e9;
for (auto e : E[x]) {
int y = e.first;
if (y == fa) continue;
dfs3(y, x);
if (!nowSize) {
nowSize += siz[y];
nowL = min(nowL, L[y]);
continue;
}
T.add(1, nowL, nowL + nowSize - 1, siz[y]);
T.add(1, L[y], L[y] + siz[y] - 1, nowSize);
c[x] += -siz[y] * nowSize - nowSize * siz[y];
nowSize += siz[y];
nowL = min(nowL, L[y]);
}
if (G[x]) {
for (auto e : E[x]) {
int y = e.first;
if (y == fa) continue;
T.add(1, L[y], L[y] + siz[y] - 1, 1);
c[x] -= siz[y];
}
}
}
void dfs4(int x, int fa) {
if (dfn[x]) {
c[x] += T.query(1, dfn[x]);
}
for (auto e : E[x]) {
int y = e.first, z = e.second;
if (y == fa) {
ef[x] = z;
continue;
}
dfs4(y, x);
c[x] += c[y];
c[x] %= mod;
}
disSum = (disSum + c[x] * ef[x]) % mod;
}
void dfs5(int x, int fa) {
d[x] = c[x];
for (auto e : E[x]) {
int y = e.first;
if (y == fa) {
continue;
}
dfs5(y, x);
d[x] += c[y];
d[x] %= mod;
}
}
signed main() {
freopen("sakuya.in", "r", stdin);
freopen("sakuya.out", "w", stdout);
cin >> n >> m;
for (int i = 1, x, y, z; i < n; i ++) {
cin >> x >> y >> z;
add(x, y, z);
add(y, x, z);
}
fac[0] = 1;
for (int i = 1; i <= m; i ++) {
cin >> g[i];
G[g[i]] = 1;
fac[i] = fac[i - 1] * i % mod;
}
dfs1(1, 0); dfs2(1, 0);
T.build(1, 1, n);
dfs3(1, 0); dfs4(1, 0); dfs5(1, 0);
cin >> q;
while (q --) {
int x, k;
cin >> x >> k;
disSum += k * d[x] % mod, disSum %= mod;
ans = 2 * fac[m - 1] % mod * disSum % mod * inv(fac[m]) % mod;
ans = (ans % mod + mod) % mod;
cout << ans << "\n";
}
return 0;
}
本文来自博客园,作者:maniubi,转载请注明原文链接:https://www.cnblogs.com/maniubi/p/18434938,orz