牛客练习赛81D 小Q与树

dsu on tree

题目链接

点我跳转

题目大意

给定一棵包含 \(n\) 个节点的树,每个节点有一个权值 \(a_i\),记 \(dis(u,v)\) 为 \(u\) 与 \(v\) 在树上的距离,求
\(\sum_{u=1}^n\sum_{v=1}^n\min(a_u,a_v)\times dis(u,v)\) 对 \(998244353\) 取模的结果

解题思路

对于节点 \(u\)

  • 记权值大于等于 \(a_u\) 的节点为 \(x1,x2,x3,...,xcnt1\)(此时 \(\min\) 取 \(a_u\))
  • 记权值小于 \(a_u\) 的节点为 \(y1,y2,...,ycnt2\)(此时 \(\min\) 取 \(a_{y}\))

那么节点 \(u\) 对答案的贡献为:

  1. \(a_u\times(dep_u + dep_{x1} - 2\times dep_{lca})+a_u\times(dep_u + dep_{x2} - 2\times dep_{lca})+...\)
  2. \(a_{y1}\times(dep_u + dep_{y1} - 2\times dep_{lca})+a_{y2}\times(dep_u + dep_{y2} - 2\times dep_{lca})+...\)

即:

  1. \(a_u\times cnt1\times (dep_u -2\times dep_{lca}) + a_u\times(dep_{x1}+dep_{x2}+...+dep_{xcnt1})\)
  2. \((a_{y1}+...+a_{ycnt2})\times (dep_u - 2\times dep_{lca})+(a_{y1}\times dep_{y1}+...+a_{ycnt2}\times dep_{ycnt2})\)

定义 \(rt\) 为当前子树的根,那么 \(lca = rt\)

开四棵权值树状数组,分别用来维护节点个数 \(cnt\)、深度之和 \(\sum dep_i\)、权值之和 \(\sum a_i\)、权值与深度乘积之和 \(\sum a_i\times dep_i\)

然后跑一遍 \(dsu~on~tree\) 即可

AC_Code

#include<bits/stdc++.h>
#define int long long
using namespace std;
// Reads one decimal integer from stdin into `res`.
// Every character before the first digit is skipped; if any of the skipped
// characters is '-', the parsed value is negated.
// If the input ends before a digit is found, `res` is set to 0 and the
// function returns instead of spinning forever.
template<typename T>void read(T &res)
{
	// Keep getchar()'s integer result: narrowing it to `char` would lose EOF
	// and could hand isdigit() a negative value, which is undefined behaviour.
	int ch = getchar();
	bool negative = false;
	while(ch != EOF && !isdigit(ch))
	{
		if(ch == '-') negative = true;
		ch = getchar();
	}
	res = 0;
	for( ; ch != EOF && isdigit(ch) ; ch = getchar())
		res = res * 10 + (ch - '0');
	if(negative) res = -res;
}
// Writes the decimal representation of `x` to stdout with putchar,
// preceded by '-' when x is negative. No trailing newline is emitted.
template<typename T>void Out(T x)
{
	if(x < 0)
	{
		putchar('-');
		x = -x;
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[64];
	int len = 0;
	do
	{
		digits[len++] = static_cast<char>('0' + x % 10);
		x /= 10;
	} while(x > 0);
	while(len > 0) putchar(digits[--len]);
}
// Capacity of the per-node arrays and the modulus applied to the answer.
const int N = 2e5 + 10;
const int mod = 998244353;

int n;        // number of nodes, read in main
int ans;      // running answer accumulated by calc/dsu
int a[N];     // node weight as read; replaced by its 1-based rank in main
int dep[N];   // depth assigned by dfs (dep[0] stays 0)
int sz[N];    // subtree size computed by dfs
int HH;       // child currently skipped by change/calc; 0 means "skip nothing"
int hson[N];  // child with the largest subtree, 0 if there is none
int M;        // number of distinct weights (size of the Fenwick index range)
// One directed arc of the linked adjacency lists.
struct Edge{
	int nex;  // index of the next arc leaving the same node; 0 ends the list
	int to;   // node this arc points to
};
// Arc pool: add_edge is called twice per input edge, hence the doubled size.
Edge edge[N << 1];
// head[u] is the index of the most recently added arc leaving u (0 if none).
int head[N];
// Number of arcs allocated so far; arc indices start at 1.
int TOT;
// Prepends a directed arc u -> v to u's adjacency list.
void add_edge(int u , int v)
{
	const int id = ++TOT;       // allocate the next free slot in the arc pool
	edge[id].to = v;
	edge[id].nex = head[u];     // the previous first arc becomes the successor
	head[u] = id;
}
// Fenwick tree (binary indexed tree) over positions 1..M holding sums
// modulo `mod`. Every stored cell and every returned value is kept in
// [0, mod).
struct TR{
	int tr[N];
	// Lowest set bit of x.
	int lowbit(int x){
		return x & (-x);
	}
	// Adds `val` (any sign, any magnitude) to position `pos`.
	void add(int pos , int val)
	{
		// lowbit(0) == 0, so a non-positive position would never advance.
		if(pos <= 0) return ;
		// Bring val into [0, mod) first: a single "+ mod" cannot repair a
		// value below -mod, and '%' keeps the sign of a negative dividend.
		val %= mod;
		if(val < 0) val += mod;
		for( ; pos <= M ; pos += lowbit(pos))
			tr[pos] = (tr[pos] + val) % mod;
	}
	// Sum of positions 1..pos, reduced modulo `mod`.
	int query(int pos)
	{
		int res = 0;
		for( ; pos > 0 ; pos -= lowbit(pos))
			res = (res + tr[pos]) % mod;
		return res;
	}
	// Sum of positions L..R, reduced modulo `mod`; 0 when R == L - 1.
	int get_sum(int L , int R){
		return (query(R) - query(L - 1) + mod) % mod;
	}
} tree1 , tree2 , tr1 , tr2;
// Sorted, de-duplicated list of weights (filled and prepared in main).
std::vector<int> vec;
// Returns the 1-based position of the first element of `vec` that is not
// smaller than x (vec.size() + 1 when every element is smaller).
int get_id(int x){
	return std::distance(vec.begin() , std::lower_bound(vec.begin() , vec.end() , x)) + 1;
}
// First pass over the tree rooted at the initial call: fills dep[], sz[]
// and hson[] for the subtree of u, where `far` is u's parent (0 for the root).
void dfs(int u , int far)
{
	sz[u] = 1;
	dep[u] = dep[far] + 1;
	for(int e = head[u] ; e != 0 ; e = edge[e].nex)
	{
		const int child = edge[e].to;
		if(child != far)
		{
			dfs(child , u);
			sz[u] += sz[child];
			// Strict comparison: on equal sizes the earlier child is kept.
			if(sz[hson[u]] < sz[child]) hson[u] = child;
		}
	}
}
// Adds (val = 1) or removes (val = -1) every node in the subtree of u to/from
// the four Fenwick trees, not descending into the child currently stored in HH.
// `far` is u's parent.
void change(int u , int far , int val)
{
	const int slot = a[u];              // compressed weight = Fenwick position
	const int weight = vec[slot - 1];   // original weight of u
	tr1.add(slot , val);                        // node count
	tr2.add(slot , val * weight);               // sum of weights
	tree1.add(slot , dep[u] * val);             // sum of depths
	tree2.add(slot , weight * dep[u] * val);    // sum of weight * depth
	for(int e = head[u] ; e != 0 ; e = edge[e].nex)
	{
		const int child = edge[e].to;
		if(child != far && child != HH) change(child , u , val);
	}
}
// Adds to `ans` the contribution of every pair (x, y) where x lies in the
// subtree of u (skipping the child stored in HH) and y is a node currently
// stored in the Fenwick trees, using rt as the pair's LCA:
//   min(a_x, a_y) * (dep_x + dep_y - 2 * dep_rt).
// `far` is u's parent. `ans` is left in [0, mod).
//
// All arithmetic is reduced modulo `mod` step by step. `int` is 64-bit in
// this file (see the `#define int long long` at the top), so a product of two
// residues fits, whereas the unreduced weight * depth * count may not.
void calc(int u , int far , int rt)
{
	const int w = vec[a[u] - 1] % mod;                 // original weight of u
	const int shift = (dep[u] - 2 * dep[rt]) % mod;    // dep_u - 2*dep_rt, may be negative

	// Stored nodes with weight >= a_u: the minimum is a_u.
	//   a_u * (cnt * (dep_u - 2*dep_rt) + sum of their depths)
	int cnt = tr1.get_sum(a[u] , M) % mod;
	int sum = tree1.get_sum(a[u] , M) % mod;
	int part = (cnt * shift % mod + sum) % mod;
	part = w * part % mod;
	// '%' keeps the dividend's sign, so fold back into [0, mod) explicitly.
	ans = ((ans + part) % mod + mod) % mod;

	// Stored nodes with weight < a_u: the minimum is their own weight.
	//   sum(a_y * dep_y) + sum(a_y) * (dep_u - 2*dep_rt)
	sum = tree2.get_sum(1 , a[u] - 1) % mod;
	cnt = tr2.get_sum(1 , a[u] - 1) % mod;
	part = (sum + cnt * shift % mod) % mod;
	ans = ((ans + part) % mod + mod) % mod;

	for(int i = head[u] ; i ; i = edge[i].nex)
	{
		int v = edge[i].to;
		if(v == far || v == HH) continue ;
		calc(v , u , rt);
	}
}
// DSU-on-tree step for the subtree of u (`far` is u's parent).
// Accumulates into `ans` every pair whose LCA is u, each unordered pair once.
// op == 1: the subtree's nodes stay in the Fenwick trees on return.
// op == 0: they are removed again before returning.
// `ans` is left in [0, mod).
//
// All arithmetic is reduced modulo `mod` step by step. `int` is 64-bit in
// this file (see the `#define int long long` at the top), so a product of two
// residues fits, whereas the unreduced weight * depth * count may not.
void dsu(int u , int far , int op)
{
	// Solve the light children first; each cleans up after itself.
	for(int i = head[u] ; i ; i = edge[i].nex)
	{
		int v = edge[i].to;
		if(v == far || v == hson[u]) continue ;
		dsu(v , u , 0);
	}
	// Solve the heavy child last and keep its nodes in the Fenwick trees.
	// HH makes calc/change below skip that already-inserted subtree.
	if(hson[u]) dsu(hson[u] , u , 1) , HH = hson[u];
	// Pair each light subtree with everything inserted so far, then insert it.
	for(int i = head[u] ; i ; i = edge[i].nex)
	{
		int v = edge[i].to;
		if(v == far || v == HH) continue;
		calc(v , u , u) , change(v , u , 1);
	}

	// Pair u itself with every node now stored; the LCA is u, so the
	// distance term dep_u + dep_y - 2*dep_u collapses to dep_y - dep_u.
	const int w = vec[a[u] - 1] % mod;     // original weight of u
	const int shift = (-dep[u]) % mod;     // dep_u - 2*dep_u

	// Stored nodes with weight >= a_u: the minimum is a_u.
	int cnt = tr1.get_sum(a[u] , M) % mod;
	int sum = tree1.get_sum(a[u] , M) % mod;
	int part = (cnt * shift % mod + sum) % mod;
	part = w * part % mod;
	// '%' keeps the dividend's sign, so fold back into [0, mod) explicitly.
	ans = ((ans + part) % mod + mod) % mod;

	// Stored nodes with weight < a_u: the minimum is their own weight.
	sum = tree2.get_sum(1 , a[u] - 1) % mod;
	cnt = tr2.get_sum(1 , a[u] - 1) % mod;
	part = (sum + cnt * shift % mod) % mod;
	ans = ((ans + part) % mod + mod) % mod;

	// Insert u itself.
	tree1.add(a[u] , dep[u]);
	tree2.add(a[u] , vec[a[u] - 1] * dep[u]);
	tr1.add(a[u] , 1);
	tr2.add(a[u] , vec[a[u] - 1]);
	// With HH cleared, change() walks the whole subtree, heavy child included.
	HH = 0;
	if(!op) change(u , far , -1);
}
// Reads the tree, compresses the weights to ranks 1..M, runs the two passes
// and prints twice the accumulated answer modulo `mod` (dsu counts each
// unordered pair once, the double sum counts it twice).
signed main()
{
	read(n);
	// Nothing to sum without nodes; also avoids indexing an empty `vec` below.
	if(n <= 0)
	{
		Out(0) , puts("");
		return 0;
	}
	for(int i = 1 ; i <= n ; i ++) read(a[i]) , vec.push_back(a[i]);
	for(int i = 1 ; i <  n ; i ++)
	{
		int u , v;
		read(u) , read(v);
		add_edge(u , v) , add_edge(v , u);
	}
	// Coordinate compression: a[i] becomes the 1-based rank of its weight.
	sort(vec.begin() , vec.end());
	vec.erase(unique(vec.begin() , vec.end()) , vec.end());
	for(int i = 1 ; i <= n ; i ++) a[i] = get_id(a[i]);
	M = vec.size();
	dfs(1 , 0);
	dsu(1 , 0 , 1);
	// '%' keeps the sign of a negative dividend, so fold the result into
	// [0, mod) before printing.
	const int doubled = ((ans % mod) * 2 % mod + mod) % mod;
	Out(doubled) , puts("");
	return 0;
}
posted @ 2021-07-01 12:07  GsjzTle  阅读(170)  评论(1编辑  收藏  举报