Luogu 5439 XR-2永恒
\(T\) 是节点数为 \(n\) 的那棵树,\(T'\) 是 Trie 树。带 \('\) 的,比如 \(\text{dep}'_u\),表示 Trie 上的信息(注意到 \(\text{dep}'\) 要从 \(0\) 开始),不带的表示原树。\([u,v]\) 表示 \(u\to v\) 的路径,\(S\) 是原树上无序点对的全集。
那么答案就是:
考虑枚举 \((u,v)\) 分情况计算贡献:
- \(\text{lca(u,v)}\neq u,v\):显然包含 \([u,v]\) 的路径数为 \(\text{siz}_u\text{siz}_v\)
- 否则设 \(u\) 为 \(v\) 的祖先,\(w\) 为 \(u\) 向 \(v\) 延伸的那个儿子,那么经过 \([u,v]\) 的路径数为 \((n-\text{siz}_w)\text{siz}_v\)。
考虑第一种贡献咋算,根据某道经典题的套路,一个简单的想法是枚举 \(u\),然后 \([d_u\to \text{root}']\) 上面所有点增加 \(\text{siz}_u\),然后枚举 \(v\),对 \([d_v\to\text{root}']\) 求和乘上 \(\text{siz}_v\) 即可。但是要除掉根的贡献,因为 \(\text{dep}'_{d_\text{root}}=0\)。
然后对于第二种贡献,我们直接对 \(T\) 搜一边,到了一个点 \(v\),考虑它到 \(\text{root}\) 的路径对 \(v\) 的贡献和。我们到一个点 \(u\),枚举出边 \(u\to w\),直接给 \([u,\text{root}]\) 加上 \(n-\text{siz}_w\),然后往 \(w\) 走,回溯的时候再减去即可。
但是这样会算重,发现在第一种贡献中,在 \(u\) 子树内的 \(v\) 也被计算了。冷静分析一下,因为算重的 \((u,v)\) 均满足 \(v\) 在 \(u\) 子树内,所以我们做第二种贡献的时候将增量 \(n-\text{siz}_w\) 变成 \(n-\text{siz}_w-\text{siz}_u\) 即可。
然后线段树随便做了,复杂度 \(O(n\log^2 n)\)。自认为写得很清新。
#include <bits/stdc++.h>
using namespace std;
namespace vbzIO {
char ibuf[(1 << 20) + 1], *iS, *iT;
#if ONLINE_JUDGE
#define gh() (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, (1 << 20) + 1, stdin), (iS == iT ? EOF : *iS++) : *iS++)
#else
#define gh() getchar()
#endif
#define pc putchar
#define pi pair<int, int>
#define tu3 tuple<int, int, int>
#define tu4 tuple<int, int, int, int>
#define mp make_pair
#define mt make_tuple
#define fi first
#define se second
#define pb push_back
#define ins insert
#define era erase
inline int read () {
char ch = gh();
int x = 0;
bool t = 0;
while (ch < '0' || ch > '9') t |= ch == '-', ch = gh();
while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + (ch ^ 48), ch = gh();
return t ? ~(x - 1) : x;
}
inline void write(int x) {
if (x < 0) {
x = ~(x - 1);
putchar('-');
}
if (x > 9)
write(x / 10);
putchar(x % 10 + '0');
}
}
using vbzIO::read;
using vbzIO::write;
const int mod = 998244353;
const int maxn = 3e5 + 300;
const int inv2 = (mod + 1) / 2;
struct seg { int sum, tag; } tr[maxn << 2];
int n, m, rt, dfc, ans, fa[maxn], a[maxn];
int sz[maxn], dep[maxn], siz[maxn], fath[maxn], son[maxn], top[maxn], id[maxn], d[maxn];
vector<int> t1[maxn], t2[maxn];
#define ls x << 1
#define rs x << 1 | 1
#define mid ((l + r) >> 1)
void pushup(int x) { tr[x].sum = (tr[ls].sum + tr[rs].sum) % mod; }
void pushtag(int x, int c, int l, int r) { (tr[x].sum += 1ll * c * (r - l + 1) % mod) %= mod, (tr[x].tag += c) %= mod; }
void pushdown(int x, int l, int r) {
if (!tr[x].tag) return;
pushtag(ls, tr[x].tag, l, mid), pushtag(rs, tr[x].tag, mid + 1, r);
tr[x].tag = 0;
}
void upd(int l, int r, int s, int t, int c, int x) {
if (s <= l && r <= t) return pushtag(x, c, l, r);
pushdown(x, l, r);
if (s <= mid) upd(l, mid, s, t, c, ls);
if (t > mid) upd(mid + 1, r, s, t, c, rs);
pushup(x);
}
int qry(int l, int r, int s, int t, int x) {
if (s <= l && r <= t) return tr[x].sum;
int res = 0;
pushdown(x, l, r);
if (s <= mid) (res += qry(l, mid, s, t, ls)) %= mod;
if (t > mid) (res += qry(mid + 1, r, s, t, rs)) %= mod;
return res;
}
void updp(int u, int c) {
c %= mod;
while (u) {
upd(0, m, id[top[u]], id[u], c, 1);
u = fath[top[u]];
}
}
int qryp(int u) {
int res = 0;
while (u) {
(res += qry(0, m, id[top[u]], id[u], 1)) %= mod;
u = fath[top[u]];
}
return res;
}
void dfs1(int u, int fat) {
siz[u] = 1, dep[u] = dep[fat] + 1;
for (int v : t2[u]) {
dfs1(v, u), siz[u] += siz[v];
if (siz[v] > siz[son[u]]) son[u] = v;
}
}
void dfs2(int u, int pre) {
top[u] = pre, id[u] = ++dfc;
if (son[u]) dfs2(son[u], pre);
for (int v : t2[u]) {
if (v == son[u]) continue;
dfs2(v, v);
}
}
void dfs3(int u) {
sz[u] = 1;
for (int v : t1[u])
dfs3(v), sz[u] += sz[v];
}
void dfs4(int u) {
for (int v : t2[u])
dfs4(v), (d[u] += d[v]) %= mod;
}
void dfs5(int u, int fat) {
(d[u] += d[fat]) %= mod;
for (int v : t2[u]) dfs5(v, u);
}
void dfs6(int u) {
(ans += 1ll * sz[u] * qryp(a[u]) % mod) %= mod;
for (int v : t1[u]) {
updp(a[u], mod + n - sz[u] - sz[v]);
dfs6(v);
updp(a[u], mod - n + sz[u] + sz[v]);
}
}
int main() {
n = read(), m = read();
for (int i = 1; i <= n; i++) {
fa[i] = read();
if (!fa[i]) rt = i;
else t1[fa[i]].pb(i);
}
for (int i = 1; i <= m; i++) {
fath[i] = read();
if (fath[i]) t2[fath[i]].pb(i);
}
scanf("%*s");
for (int i = 1; i <= n; i++) a[i] = read();
for (int u : t2[1]) dep[u] = 1, fath[u] = 0, dfs1(u, 0), dfs2(u, u);
dfs3(rt);
for (int i = 1; i <= n; i++) (d[a[i]] += sz[i]) %= mod;
dfs4(1), d[1] = 0, dfs5(1, 0);
for (int i = 1; i <= n; i++) {
ans = (ans + 1ll * sz[i] * d[a[i]] % mod) % mod;
ans = (ans + mod - 1ll * sz[i] * sz[i] % mod * dep[a[i]] % mod) % mod;
}
ans = 1ll * ans * inv2 % mod;
dfs6(rt);
write(ans);
return 0;
}