链接:https://ac.nowcoder.com/acm/contest/11171/D
来源:牛客网
时间限制:C/C++ 2秒,其他语言4秒
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld
题目描述:
小 Q 在纸上画树,画着画着,小 Q 在纸上画出了一棵\(n\)个点,\(n−1\) 条边的树,其中他给第 \(i\) 个点都赋了一个点权\(a_i\),每条边的距离为 11。他想要知道
对 998244353 取模后的值。
输入描述:
第一行一个整数\(n(1\le n\le2^{10})\),表示点的个数。
接下来 \(n\) 个整数,其中第 \(i\) 个整数表示 \(a_i(1\le a_i\lt998244353)\)。
最后 \(n-1\) 行,每行两个整数 \(u,v(1\le u,v\le n,u\ne v)\),表示节点 \(u\) 与节点 \(v\) 之间有一条无向边。
输出描述:
一行一个整数,表示答案。
输入
5
3 2 4 6 5
1 2
2 3
2 4
4 5
输出
112
解析
根据题目描述可以知道,需要计算树上任意两个点之间的距离乘以这两个点权值最小值的和,直接肯定会超时复杂度是 \(O(n^2)\),但是可以对树进行分治来求。思想就是根据树的重心对树进行分治,拆成若干子树,对于每次分治的计算方式就是只需要计算通过该重心的任意两个点的值就可以了。如何计算呢?我是只先统计了分治后每个点到该重心的距离dp,然后再统计出每棵子树的所有节点数量,然后把所有节点按照\(a_i\)的权值从小到大进行排序,再从小到大枚举每个节点,因为是从小到大,所以每个节点与其他后面的节点的相比一定是小于等于的,所以只要统计后面还没有统计过的所有距离的和就可以了嘛。不,这样有问题,因为我们只要计算需要通过重心的点的,如果统计后面所有的节点的,就会把同一个子树的节点计算进去了,这样就不对的,所以需要删除同一棵子树的所有距离,这样还缺少了枚举点到重心的距离,所以还需要再加上该枚举点到重心的距离,这里需要注意的是,后面有多少节点,就需要加多次,这样就能够正确求出通过该节点的所有需要求的值的了。然后再依次进行分治即可。时间复杂度,树分治进行计算是\(O(n\cdot logn)\),然后还需要排序,所以总时间复杂度为\(O(n\cdot logn\cdot logn)\),能够满足条件。
代码
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <cstdio>
#include <string>
#include <cstdlib>
#include <cmath>
#include <iostream>
#include <cstring>
#include <set>
#include <queue>
#include <algorithm>
#include <vector>
#include <map>
#include <cctype>
#include <ctime>
#include <stack>
#include <sstream>
#include <list>
#include <assert.h>
#include <bitset>
#include <numeric>
#include <unordered_map>
#define debug() puts("++++")
#define print(x) cout<<"====== "<<(x)<<" ====="<<endl;
// #define gcd(a, b) __gcd(a, b)
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define fi first
#define se second
#define pb push_back
#define sqr(x) ((x)*(x))
#define ms(a,b) memset(a, b, sizeof a)
#define sz size()
#define be begin()
#define ed end()
#define pu push_up
#define pd push_down
#define cl clear()
#define lowbit(x) -x&x
// #define all 1,n,1
#define FOR(i,n,x) for(int i = (x); i < (n); ++i)
#define freopenr freopen("in.in", "r", stdin)
#define freopenw freopen("out.out", "w", stdout)
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> P;
const int INF = 0x3f3f3f3f;
const LL LNF = 1e17;
const double inf = 1e20;
const double PI = acos(-1.0);
const double eps = 1e-8;
const int maxn = 2e5 + 7;
const int maxm = 2000000 + 7;
const LL mod = 998244353;
const int dr[] = {-1, 1, 0, 0, 1, 1, -1, -1};
const int dc[] = {0, 0, 1, -1, 1, -1, 1, -1};
int n, m;
inline bool is_in(int r, int c) {
return r >= 0 && r < n && c >= 0 && c < m;
}
inline int readInt(){
int x; scanf("%d", &x); return x;
}
int a[maxn];
bool vis[maxn];
struct Edge{
int to, next;
};
int head[maxn];
int cnt;
Edge edges[maxn<<1];
void add(){
int u, v; scanf("%d %d", &u, &v);
edges[cnt].to = v;
edges[cnt].next = head[u];
head[u] = cnt++;
edges[cnt].to = u;
edges[cnt].next = head[v];
head[v] = cnt++;
}
int root, total;
int f[maxn], num[maxn];
void dfs_for_root(int u, int fa){
num[u] = 1; f[u] = 0;
for(int i = head[u]; ~i; i = edges[i].next){
int v = edges[i].to;
if(vis[v] || fa == v) continue;
dfs_for_root(v, u);
num[u] += num[v];
f[u] = max(f[u], num[v]);
}
f[u] = max(f[u], total - f[u] - 1);
if(f[u] < f[root]) root = u;
}
int dp[maxn];
vector<P> points;
LL _count[maxn], _num[maxn];
int c;
LL dfs_for_dist(int u, int fa, int rt){
dp[u] = dp[fa] + 1;
LL sum = dp[u];
++c;
points.pb(P(u, rt));
for(int i = head[u]; ~i; i = edges[i].next){
int v = edges[i].to;
if(vis[v] || fa == v) continue;
sum += dfs_for_dist(v, u, rt);
}
return sum;
}
LL solve_for_ans(int u){
points.cl;
LL sum = 0, _total = 0;
for(int i = head[u]; ~i; i = edges[i].next){
int v = edges[i].to;
if(vis[v]) continue;
c = 0;
_count[v] = dfs_for_dist(v, u, v);
_num[v] = c;
_total += c;
sum += _count[v];
}
sort(points.be, points.ed, [&](P &lhs, P &rhs){return a[lhs.fi] < a[rhs.fi];});
LL ans = 0;
for_each(points.be, points.ed, [&](P &p){
int x = p.fi;
ans = (ans + (sum - _count[p.se] + (_total - _num[p.se]) * dp[x] % mod) * a[x] + dp[x] * min(a[x], a[u])) % mod;
sum -= dp[x]; _count[p.se] -= dp[x];
--_num[p.se]; --_total;
});
return ans;
}
LL dfs(int u){
vis[u] = true;
dp[u] = 0;
LL ans = solve_for_ans(u);
for(int i = head[u]; ~i; i = edges[i].next){
int v = edges[i].to;
if(vis[v]) continue;
total = num[v];
root = 0;
dfs_for_root(v, u);
ans = (ans + dfs(root)) % mod;
}
return ans;
}
int main(){
scanf("%d", &n); ms(head, -1);
for(int i = 1; i <= n; ++i) scanf("%d", a + i);
for(int i = 1; i < n; ++i) add();
f[0] = total = n;
dfs_for_root(1, -1);
LL ans = dfs(root);
printf("%lld\n", (ans + ans) % mod);
return 0;
}