Educational Codeforces Round 2 E Lomsat gelral 线段树合并
大致题意
给你一棵有n个点的树,树上每个节点都有一种颜色ci(ci≤n),让你求每个点子树出现最多的颜色/编号的和
记性不好,主要是为了防止自己忘掉,今天和队友合作写题的时候甚至有点想不起来kmp是干嘛的,在博客园标记一下这个题目的解析和做法
关于什么时候能用线段树合并
假设我们会加入k个点,那么时间复杂度和空间复杂度都会是O(klogk)
做法和证明如下链接
线段树合并时间复杂度和空间复杂度证明
#include <iostream>
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <stack>
#include <queue>
#include <numeric>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <vector>
#include <unordered_set>
#include <cmath>
#include <map>
#include <unordered_map>
#include <set>
#include <deque>
#include <tuple>
#include <array>
#define all(a) a.begin(), a.end()
#define cnt0(x) __builtin_ctz(x)
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define cntone(x) __builtin_popcount(x)
#define db double
#define fs first
#define ls(x) (tr[x].l)
#define rs(x) (tr[x].r)
#define cnt(x) (tr[x].cnt)
#define mx(x) (tr[x].v)
#define se second
#define AC main(void)
#define HYS std::ios::sync_with_stdio(false);std::cin.tie(0);std::cout.tie(0);
typedef std::pair<int, int > PII;
typedef std::pair<int, std::pair<int, int>> PIII;
typedef std::pair<ll, ll> Pll;
typedef std::pair<double, double> PDD;
using ld = double long;
const long double eps = 1e-9;
const int INF = 0x3f3f3f3f;
const int N = 2e5 + 10, M = 4e5 + 10;
int n , m, p;
int d1[] = {0, 0, 1, -1};
int d2[] = {1, -1, 0, 0};
int h[N], ne[N << 1], w[N << 1], e[N << 1], root[N];
int tot, idx;
ll ans[N];
struct node{
int l, r;
ll v;
int cnt;
}tr[N * 20];//最多添加1e5个点 每个点最多log层 所以乘以一个log的空间
void pushup(int u){
if(tr[ls(u)].cnt > tr[rs(u)].cnt){
tr[u].cnt = tr[ls(u)].cnt;
tr[u].v = tr[ls(u)].v;
}
else if(tr[ls(u)].cnt < tr[rs(u)].cnt){
tr[u].cnt = tr[rs(u)].cnt;
tr[u].v = tr[rs(u)].v;
}else{
tr[u].cnt = tr[ls(u)].cnt;
tr[u].v = tr[ls(u)].v + tr[rs(u)].v;
}
}
inline void add(int a, int b){
ne[idx] = h[a], e[idx] = b, h[a] = idx ++;
}
//线段树合并操作 当没有一个儿子的时候会O1否则logn
inline int merge(int p, int q, int L, int R){
if(!p) return q;
if(!q) return p;
if(L == R){
tr[p].cnt += tr[q].cnt;
tr[p].v = L;
return p;
}
int mid = L + R >> 1;
tr[p].l = merge(tr[p].l, tr[q].l, L, mid);
tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, R);
pushup(p);
return p;
}
inline void update(int u, int L, int R, int x, int sum){
if(L == R){
tr[u].cnt += sum;
tr[u].v = L;
return ;
}
int mid = L + R >> 1;
if(x <= mid){
if(!tr[u].l) tr[u].l = ++ tot;
update(tr[u].l, L, mid, x, sum);
}
else{
if(!tr[u].r) tr[u].r = ++ tot;
update(tr[u].r, mid + 1, R ,x, sum);
}
pushup(u);
}
inline void dfs(int u, int fa){
root[u] = ++ tot;
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(fa == j) continue;
dfs(j, u);
merge(root[u], root[j], 1, n);
}
update(root[u], 1, n, w[u], 1);
ans[u] = tr[root[u]].v;
}
inline void solve(){
std::cin >> n;
memset(h, -1, sizeof h);
for(int i = 1; i <= n; i ++) std::cin >> w[i];
for(int i = 1; i < n; i ++){
int a, b;
std::cin >> a >> b;
add(a, b), add(b, a);
}
dfs(1, -1);
for(int i = 1; i <= n; i ++) std::cout << ans[i] << ' ';
}
signed AC{
HYS
int _ = 1;
//std::cin >> _;
while(_ --)
solve();
return 0;
}