2020-2021 “Orz Panda” Cup Programming Contest G题(树形结构)
题目大意:给点一颗包含 \(n\)个节点的无根树,有 \(m\)次询问,每次询问给出两个点 \(u\)和 \(v\),要求计算
\(d_{r}(u,v)\)是以 \(r\)为根的树上 \(u\)到 \(v\)的“美丽路径”,它的定义为:
其中 \(lca_{r}(u,v)\)是以节点 \(r\)为根的树中,点 \(u\)和点 \(v\)的最近公共祖先。\(dis(u,v)\)等于 \(u\),\(v\)之间最短路径的边数。
输入:第一行输入 \(n,m\),接下来 \(n-1\)行给出连边情况,接下来 \(m\)行代表 \(m\)组询问。
输出:对于每个询问输出答案对998244353取模
数据范围:\(1 \leq n,m \leq 1e5\)
分析:令节点 \(1\)为根,简化问题。考虑要算的东西,发现它只与 \(u\)到 \(v\)的路径上的节点以及这些节点的“分支节点”有关。不明白的话可以画图具体算一下。考虑点 \(u\)到 \(lca\)上的节点 \(u_{1},u_{2}...u_{k}\),假设 \(u_{p}\)为 \(u_{k}\)的“分支节点”,那么无论是以 \(u_{k}\)为根还是以 \(u_{p}\)为根, \(lca(u,v)\)都等于 \(u_{k}\),也就是说可以把 \(u_{k}\)的“分支节点”对答案的贡献累加到 \(u_{k}\)上。假设原本 \(u_{k}\)对答案的贡献为 \(w\),那么现在就等于 \((num+1) \cdot w\),\(num\)是“分支节点”的个数,设 \(siz[x]\)是以 \(1\)为根的树中以 \(x\)为根的子树大小,那么 \(num=siz[u_{k}]-siz[u_{k-1}]\),设 \(u,v\)之间的距离为 \(dis\),\(dis=dep[u]+dep[v]-2 \times dep[lca]\)。那么 \(u\)到 \(lca\)上的节点 \(u_{1},u_{2}...u_{k}\)对答案的贡献就等于
把它拆成 \(8\)项,分别计算就好,求下前缀和就可以 \(O(1)\)计算。对于 \(v\)到 \(lca\)的那部分贡献同理计算。另外 \(lca\)对答案的贡献需要另算。
#include<cstdio>
typedef long long ll;
const int N = 1e5 + 5;
const int mod = 998244353;
int n, m, cnt, son_u, son_v;
int head[N], dep[N], son[N], fa[N], top[N];
ll d_siz[N], d2_siz[N], fa_d_siz[N], fa_d2_siz[N], siz[N];
// son_u表示u到lca路径上离lca最近的点,son_v同理
// d_siz[x] = dep[x] * siz[x]
// d2_siz[x] = dep[x] * dep[x] * siz[x]
// fa_d_siz[x] = dep[fa[x]] * siz[x]
// da_d2_siz[x] = dep[fa[x]] * dep[fa[x]] * siz[x]
struct Edge{
int nex, to;
}e[N << 1];
inline ll max(ll a, ll b) { return a > b ? a : b; }
inline void add(int a, int b) { e[++cnt] = {head[a], b}; head[a] = cnt; }
void dfs1(int u, int f){
dep[u] = dep[f] + 1, fa[u] = f, siz[u] = 1;
for(int i = head[u]; i; i = e[i].nex){
int to = e[i].to;
if(to == f) continue;
dfs1(to, u);
if(siz[to] > siz[son[u]]) son[u] = to;
siz[u] += siz[to];
}
}
void dfs2(int u, int ttop){
top[u] = ttop;
if(son[u]) dfs2(son[u], ttop);
for(int i = head[u]; i; i = e[i].nex){
int to = e[i].to;
if(to == fa[u] || to == son[u]) continue;
dfs2(to, to);
}
}
void dfs3(int u, int f){
d_siz[u] = (1LL * dep[u] * siz[u] + d_siz[f]) % mod;
d2_siz[u] = (1LL * dep[u] * dep[u] % mod * siz[u] + d2_siz[f]) % mod;
fa_d_siz[u] = (1LL * dep[fa[u]] * siz[u] + fa_d_siz[f]) % mod;
fa_d2_siz[u] = (1LL * dep[fa[u]] * dep[fa[u]] % mod * siz[u] + fa_d2_siz[f]) % mod;
for(int i = head[u]; i; i = e[i].nex){
int to = e[i].to;
if(to == f) continue;
dfs3(to, u);
}
}
// 找lca和son_u,son_v
int get_lca(int u, int v){
while(top[u] != top[v]){
if(dep[top[u]] > dep[top[v]]) son_u = top[u], u = fa[top[u]];
else son_v = top[v], v = fa[top[v]];
}
if(dep[u] > dep[v]) son_u = son[v];
else son_v = son[u];
return dep[u] > dep[v] ? v : u;
}
ll cal(int u, int v, ll *p){
// u或v等于0说明son_u不存在,返回0
return (dep[u] < dep[v] || v == 0 || u == 0) ? 0 : p[u] - p[v];
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 1, u, v; i < n; ++i){
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
dfs1(1, 0);
dfs2(1, 1);
dfs3(1, 0);
for(int i = 1, u, v; i <= m; ++i){
scanf("%d%d", &u, &v);
ll lca = get_lca(u, v), dis = dep[u] + dep[v] - (dep[lca] << 1);
if(u == lca) son_u = 0;
if(v == lca) son_v = 0;
ll ans = 1LL * (n - siz[son_u] - siz[son_v]) * (dep[u] - dep[lca]) % mod * (dep[v] - dep[lca]) % mod; // lca的贡献
ans -= ((dis - dep[u]) * cal(fa[u], lca, d_siz) + (dis - dep[v]) * cal(fa[v], lca, d_siz)) % mod;
ans += ((dis - dep[u]) * dep[u] % mod * max(0, siz[son_u] - siz[u]) + (dis - dep[v]) * dep[v] % mod * max(0, siz[son_v] - siz[v])) % mod;
ans += ((dis - dep[u]) * cal(u, son_u, fa_d_siz) + (dis - dep[v]) * cal(v, son_v, fa_d_siz)) % mod;
ans -= cal(fa[u], lca, d2_siz) + cal(fa[v], lca, d2_siz);
ans += cal(u, son_u, fa_d2_siz) + cal(v, son_v, fa_d2_siz);
ans += (dep[u] * cal(fa[u], lca, d_siz) + dep[v] * cal(fa[v], lca, d_siz)) % mod;
ans -= (dep[u] * cal(u, son_u, fa_d_siz) + dep[v] * cal(v, son_v, fa_d_siz)) % mod;
printf("%lld\n", (ans % mod + mod) % mod);
}
return 0;
}