[UOJ388]配对树
题解
贪心+线段树
首先如果我们知道了哪些点是关键点应该怎么搞
显然最小的匹配方案所有的边至多被经过一次
可以考虑每条边的贡献
因为我们要贡献尽量小
所以我们尽量让每条边经过的人尽量少
那么每条边被经过的条件就是一条边连接的两个节点的子树内关键点数量是奇数
然后我们可以直接考虑每条边会被多少个区间影响
一条边能被一个区间影响的条件就是这个区间的大小是偶数且这个区间内在这个边的深度较深的端点的子树内的点的数量有奇数个
也就是满足\(j-i=0(\mod 2) s_j - s_i=1(\mod 2)\)
这个东西可以用线段树维护
就记录下标为\(奇数/偶数\)且前缀和为奇数的数量
那么修改就是修改从这个点到\(m\)都翻转,并且可以直接计算了
然后可以用线段树合并来处理这个东西
但是我写了半天写不动><
就直接上dsu on tree来搞子树了==
代码
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
# define ls (now << 1)
# define rs (now << 1 | 1)
const int M = 100005 ;
const int N = 20 ;
const int mod = 998244353 ;
using namespace std ;
inline int read() {
char c = getchar() ; int x = 0 , w = 1 ;
while(c>'9'||c<'0') { if(c=='-') w = -1 ; c = getchar() ; }
while(c>='0'&&c<='9') { x = x*10+c-'0' ; c = getchar() ; }
return x*w ;
}
int n , m ;
int num , hea[M] ;
int cnt , ans , fdis[M] ;
int size[M] , bsn[M] ;
vector < int > vec[M] ;
struct E {
int nxt , to , dis ;
} edge[M * 2] ;
struct Node {
bool rev ;
int odd , even ;
} t[M * N] ;
inline void add_edge(int from , int to , int dis) {
edge[++num].nxt = hea[from] ;
edge[num].to = to ;
edge[num].dis = dis ;
hea[from] = num ;
}
inline void Flip(int now , int l , int r) {
t[now].even = r / 2 - (l - 1) / 2 - t[now].even ;
t[now].odd = (r + 1) / 2 - l / 2 - t[now].odd ;
t[now].rev ^= 1 ;
}
inline void pushup(int now) {
t[now].odd = t[ls].odd + t[rs].odd ;
t[now].even = t[ls].even + t[rs].even ;
}
inline void pushdown(int now , int l , int r) {
if(t[now].rev) {
t[now].rev = 0 ; int mid = (l + r) >> 1 ;
if(ls) Flip(ls , l , mid) ; if(rs) Flip(rs , mid + 1 , r) ;
}
}
void Change(int L , int R , int l , int r , int now) {
if(!now || l > R || r < L) return ;
if(l >= L && r <= R) { Flip(now , l , r) ; return ; }
pushdown(now , l , r) ;
int mid = (l + r) >> 1 ;
if(mid >= R) Change(L , R , l , mid , ls) ;
else if(mid < L) Change(L , R , mid + 1 , r , rs) ;
else { Change(L , mid , l , mid , ls) ; Change(mid + 1 , R , mid + 1 , r , rs) ; }
pushup(now) ;
}
inline void Calc(int u) {
int sum_e = m / 2 + 1 , sum_o = (m + 1) / 2 ;
ans = ((ans + 1LL * fdis[u] * t[1].even % mod * (sum_e - t[1].even) % mod) % mod + 1LL * fdis[u] * t[1].odd % mod * (sum_o - t[1].odd) % mod) % mod ;
}
void fdfs(int u , int father) {
size[u] = 1 ; int mx = -1 ;
for(int i = hea[u] ; i ; i = edge[i].nxt) {
int v = edge[i].to ; if(v == father) continue ;
fdis[v] = edge[i].dis ;
fdfs(v , u) ; size[u] += size[v] ;
if(size[v] > mx) mx = size[v] , bsn[u] = v ;
}
}
inline void Add(int u) {
for(int i = 0 , x , sz = vec[u].size() ; i < sz ; i ++) {
x = vec[u][i] ;
Change(x , m , 1 , m , 1) ;
}
}
void Update(int u , int father) {
Add(u) ;
for(int i = hea[u] ; i ; i = edge[i].nxt) {
int v = edge[i].to ;
if(v != father)
Update(v , u) ;
}
}
void dfs(int u , int father , bool Kp) {
for(int i = hea[u] ; i ; i = edge[i].nxt) {
int v = edge[i].to ;
if(v != father && v != bsn[u])
dfs(v , u , 0) ;
}
if(bsn[u]) dfs(bsn[u] , u , 1) ;
for(int i = hea[u] ; i ; i = edge[i].nxt) {
int v = edge[i].to ;
if(v != father && v != bsn[u])
Update(v , u) ;
}
Add(u) ;
if(fdis[u]) {
int sum_e = m / 2 + 1 , sum_o = (m + 1) / 2 ;
ans = (ans + 1LL * fdis[u] * t[1].even % mod * (sum_e - t[1].even) % mod) % mod ;
ans = (ans + 1LL * fdis[u] * t[1].odd % mod * (sum_o - t[1].odd) % mod) % mod ;
}
if(!Kp) Update(u , father) ;
}
int main() {
n = read() ; m = read() ;
for(int i = 1 , u , v , w ; i < n ; i ++) {
u = read() ; v = read() ; w = read() ;
add_edge(u , v , w) ; add_edge(v , u , w) ;
}
for(int i = 1 , x ; i <= m ; i ++) {
x = read() ;
vec[x].push_back(i) ;
}
fdfs(1 , 0) ;
dfs(1 , 0 , 1) ;
printf("%d\n",ans) ;
return 0 ;
}