2020-2021 “Orz Panda” Cup Programming Contest G题(树形结构)

题目传送门

题目大意:给点一颗包含 \(n\)个节点的无根树,有 \(m\)次询问,每次询问给出两个点 \(u\)\(v\),要求计算

\[\sum_{r=1}^{n}d_{r}(u,v) \]

\(d_{r}(u,v)\)是以 \(r\)为根的树上 \(u\)\(v\)的“美丽路径”,它的定义为:

\[d_{r}(u,v)=dis(u,lca_{r}(u,v)) \times dis(v,lca_{r}(u,v)) \]

其中 \(lca_{r}(u,v)\)是以节点 \(r\)为根的树中,点 \(u\)和点 \(v\)的最近公共祖先。\(dis(u,v)\)等于 \(u\)\(v\)之间最短路径的边数。

输入:第一行输入 \(n,m\),接下来 \(n-1\)行给出连边情况,接下来 \(m\)行代表 \(m\)组询问。

输出:对于每个询问输出答案对998244353取模

数据范围\(1 \leq n,m \leq 1e5\)

分析:令节点 \(1\)为根,简化问题。考虑要算的东西,发现它只与 \(u\)\(v\)的路径上的节点以及这些节点的“分支节点”有关。不明白的话可以画图具体算一下。考虑点 \(u\)\(lca\)上的节点 \(u_{1},u_{2}...u_{k}\),假设 \(u_{p}\)\(u_{k}\)的“分支节点”,那么无论是以 \(u_{k}\)为根还是以 \(u_{p}\)为根, \(lca(u,v)\)都等于 \(u_{k}\),也就是说可以把 \(u_{k}\)的“分支节点”对答案的贡献累加到 \(u_{k}\)上。假设原本 \(u_{k}\)对答案的贡献为 \(w\),那么现在就等于 \((num+1) \cdot w\)\(num\)是“分支节点”的个数,设 \(siz[x]\)是以 \(1\)为根的树中以 \(x\)为根的子树大小,那么 \(num=siz[u_{k}]-siz[u_{k-1}]\),设 \(u,v\)之间的距离为 \(dis\)\(dis=dep[u]+dep[v]-2 \times dep[lca]\)。那么 \(u\)\(lca\)上的节点 \(u_{1},u_{2}...u_{k}\)对答案的贡献就等于

\[\sum_{r=1}^{k}(siz[u_{k}]-siz[u_{k-1}]) \times (dep[u]-dep[u_{k}]) \times (dis-(dep[u]-dep[u_{k}])) \]

把它拆成 \(8\)项,分别计算就好,求下前缀和就可以 \(O(1)\)计算。对于 \(v\)\(lca\)的那部分贡献同理计算。另外 \(lca\)对答案的贡献需要另算。

#include<cstdio>
typedef long long ll;
const int N = 1e5 + 5;
const int mod = 998244353;

int n, m, cnt, son_u, son_v;
int head[N], dep[N], son[N], fa[N], top[N];
ll d_siz[N], d2_siz[N], fa_d_siz[N], fa_d2_siz[N], siz[N];
// son_u表示u到lca路径上离lca最近的点,son_v同理
// d_siz[x] = dep[x] * siz[x]
// d2_siz[x] = dep[x] * dep[x] * siz[x]
// fa_d_siz[x] = dep[fa[x]] * siz[x]
// da_d2_siz[x] = dep[fa[x]] * dep[fa[x]] * siz[x]

struct Edge{
	int nex, to;
}e[N << 1];

inline ll max(ll a, ll b) { return a > b ? a : b; }
inline void add(int a, int b) { e[++cnt] = {head[a], b};  head[a] = cnt; }

void dfs1(int u, int f){
	dep[u] = dep[f] + 1, fa[u] = f, siz[u] = 1;
	for(int i = head[u]; i; i = e[i].nex){
		int to = e[i].to;
		if(to == f)  continue;
		dfs1(to, u);
		if(siz[to] > siz[son[u]])  son[u] = to;
		siz[u] += siz[to];
	}
}

void dfs2(int u, int ttop){
	top[u] = ttop;
	if(son[u])  dfs2(son[u], ttop);
	for(int i = head[u]; i; i = e[i].nex){
		int to = e[i].to;
		if(to == fa[u] || to == son[u])  continue;
		dfs2(to, to);
	}
}

void dfs3(int u, int f){
	d_siz[u] = (1LL * dep[u] * siz[u] + d_siz[f]) % mod;
	d2_siz[u] = (1LL * dep[u] * dep[u] % mod * siz[u] + d2_siz[f]) % mod;
	fa_d_siz[u] = (1LL * dep[fa[u]] * siz[u] + fa_d_siz[f]) % mod;
	fa_d2_siz[u] = 	(1LL * dep[fa[u]] * dep[fa[u]] % mod * siz[u] + fa_d2_siz[f]) % mod;
	for(int i = head[u]; i; i = e[i].nex){
		int to = e[i].to;
		if(to == f)  continue;
		dfs3(to, u);
	}
}

// 找lca和son_u,son_v
int get_lca(int u, int v){
	while(top[u] != top[v]){
		if(dep[top[u]] > dep[top[v]])  son_u = top[u], u = fa[top[u]];
		else  son_v = top[v], v = fa[top[v]];
	}
	if(dep[u] > dep[v])  son_u = son[v];
	else  son_v = son[u];
	return dep[u] > dep[v] ? v : u;
}

ll cal(int u, int v, ll *p){
	// u或v等于0说明son_u不存在,返回0
	return (dep[u] < dep[v] || v == 0 || u == 0) ? 0 : p[u] - p[v];
}

int main(){
	scanf("%d%d", &n, &m);
	for(int i = 1, u, v; i < n; ++i){
		scanf("%d%d", &u, &v);
		add(u, v),  add(v, u);
	}
	dfs1(1, 0);
	dfs2(1, 1);
	dfs3(1, 0);
	for(int i = 1, u, v; i <= m; ++i){
		scanf("%d%d", &u, &v);
		ll lca = get_lca(u, v),  dis = dep[u] + dep[v] - (dep[lca] << 1);
		if(u == lca)  son_u = 0;
		if(v == lca)  son_v = 0;
		ll ans = 1LL * (n - siz[son_u] - siz[son_v]) * (dep[u] - dep[lca]) % mod * (dep[v] - dep[lca]) % mod;  // lca的贡献
		ans -= ((dis - dep[u]) * cal(fa[u], lca, d_siz) + (dis - dep[v]) * cal(fa[v], lca, d_siz)) % mod;
		ans += ((dis - dep[u]) * dep[u] % mod * max(0, siz[son_u] - siz[u]) + (dis - dep[v]) * dep[v] % mod * max(0, siz[son_v] - siz[v])) % mod;
		ans += ((dis - dep[u]) * cal(u, son_u, fa_d_siz) + (dis - dep[v]) * cal(v, son_v, fa_d_siz)) % mod;
		ans -= cal(fa[u], lca, d2_siz) + cal(fa[v], lca, d2_siz);
		ans += cal(u, son_u, fa_d2_siz) + cal(v, son_v, fa_d2_siz);
		ans += (dep[u] * cal(fa[u], lca, d_siz) + dep[v] * cal(fa[v], lca, d_siz)) % mod;
		ans -= (dep[u] * cal(u, son_u, fa_d_siz) + dep[v] * cal(v, son_v, fa_d_siz)) % mod;
		printf("%lld\n", (ans % mod + mod) % mod);
	}
	return 0;
}
posted @ 2020-12-02 17:23  のNice  阅读(79)  评论(0编辑  收藏  举报