点分治模板

题目链接:https://www.luogu.com.cn/problem/P3806

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

// Capacity of every per-vertex array (indexed directly by vertex id).
const int maxn = 10010;
// Starting value for maxson before each getrt() search; larger than any
// component size that can occur.
const int INF = 1000000007;

int n;       // number of vertices
int m;       // number of queries
int q[105];  // q[1..m]: the queried path lengths

// Adjacency lists kept as a linked "forward star":
//   h[u]       index of the first edge leaving u (-1 = none; main() fills h with -1),
//   e[x].next  index of the next edge in the same list,
//   cnt        number of edges stored so far (slot 0 is never used).
int h[maxn], cnt = 0;
struct E{
	int to, next, cost;
}e[maxn<<1];
// Prepends the directed edge u -> v with weight w to u's edge list.
void add(int u, int v, int w){
	e[++cnt] = E{v, h[u], w};
	h[u] = cnt;
}

int rt;           // best centroid candidate found by the latest getrt() call
int tot;          // component size getrt() measures against (set by its caller)
int tag[105];     // tag[k] != 0 once some path of length q[k] has been found
int sz[maxn];     // subtree sizes computed by the latest getrt() DFS
int son[maxn];    // son[u]: size of the largest piece left if u were removed
int vis[maxn];    // vis[u] = 1 once u has been used as a centroid
int maxson;       // smallest son[] value seen so far in the current getrt() search

void getrt(int u, int par){
	sz[u] = 1; son[u] = 0;
	for(int i = h[u] ; i != -1 ; i = e[i].next){
		int v = e[i].to;
		if(vis[v] || v == par) continue;
		getrt(v, u);
		sz[u] += sz[v];
		son[u] = max(son[u], sz[v]);
	}
	son[u] = max(son[u], tot - sz[u]);
	if(son[u] < maxson){
		maxson = son[u];
		rt = u;
	}
}

int dis[maxn], d[maxn], has[10000010], dnum;
void getdis(int u, int par){
	if(dis[u] <= 10000000) 	d[++dnum] = dis[u]; 
	for(int i = h[u] ; i != -1 ; i = e[i].next){
		int v = e[i].to, w = e[i].cost;
		if(vis[v] || v == par) continue;
		dis[v] = dis[u] + w;
		getdis(v, u);
	}
}

// alld[1..dtot]: every distance marked in has[] while handling the current
// centroid, remembered so exactly those marks can be cleared afterwards.
int alld[maxn], dtot = 0;

// For centroid u, sets tag[k] for every query q[k] that equals the length of a
// path through u. Such a path either has its two ends in different subtrees of
// u, or has u itself as one end.
void calc(int u){
	// Number of entries in has[]. A lookup index at or beyond this cannot
	// have been marked, and would read past the end of the array.
	const int hasSize = (int)(sizeof(has) / sizeof(has[0]));
	dtot = 0;
	dis[u] = 0;
	has[0] = 1;	// u itself is at distance 0, which covers paths ending at u
	for(int i = h[u] ; i != -1 ; i = e[i].next){
		int v = e[i].to;
		if(vis[v]) continue;
		dis[v] = e[i].cost;
		dnum = 0;
		getdis(v, u);
		// Match this subtree's distances against those of earlier subtrees.
		for(int j = 1 ; j <= dnum ; ++j){
			for(int k = 1 ; k <= m ; ++k){
				const int need = q[k] - d[j];
				// Check both bounds: a query larger than the table (bad input)
				// must not index past the end of has[].
				if(need >= 0 && need < hasSize) tag[k] |= has[need];
			}
		}
		// Merge only after matching, so a path never uses two ends from the
		// same subtree.
		for(int j = 1 ; j <= dnum ; ++j){
			alld[++dtot] = d[j];
			has[d[j]] = 1;
		}
	}
	// Reset has[] for the next centroid, touching only the entries set above.
	for(int i = 1 ; i <= dtot ; ++i) has[alld[i]] = 0;
	has[0] = 0;
}

// Centroid-decomposition driver. Answers every query path through centroid u,
// removes u (vis[u] = 1), then recurses into the centroid of each remaining
// component around u.
// Expects: tot == size of u's component, and sz[] filled by the getrt() pass
// that selected u. solve() sets this up, and so does the loop below.
void divi(int u){
	// tot is overwritten by the recursive calls below, so capture it now.
	const int whole = tot;
	calc(u);
	vis[u] = 1;
	for(int i = h[u] ; i != -1 ; i = e[i].next){
		int v = e[i].to;
		if(vis[v]) continue;
		// sz[v] is the size of v's side only when v lay below u in the getrt()
		// DFS (then sz[v] < sz[u]). If v was u's DFS parent, sz[v] also counts
		// u's own subtree, and v's side really holds whole - sz[u] vertices.
		// Passing the exact size keeps getrt()'s "tot - sz[u]" term correct.
		maxson = INF;
		tot = (sz[v] < sz[u]) ? sz[v] : whole - sz[u];
		getrt(v, u);
		divi(rt);
	}
}

// Entry point of the decomposition: search the whole tree (all n vertices,
// starting the DFS at vertex 1) for a centroid, then decompose from it.
void solve(){
	tot = n;
	maxson = INF;
	getrt(1, 0);
	divi(rt);
}

// Reads the next integer from stdin. Any non-digit characters before it are
// skipped; a '-' among them makes the result negative.
// Returns 0 if the input ends before a digit appears. Without the EOF check,
// the skip loop would never terminate on truncated input, because EOF is
// below '0'. ch is an int so EOF stays distinguishable from every real
// character (with a plain char this fails where char is unsigned).
// The return type long long is the same type as the file's `ll` typedef.
long long read(){
	long long s = 0, f = 1;
	int ch = getchar();
	while(ch != EOF && (ch < '0' || ch > '9')){
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9'){
		s = s * 10 + ch - '0';
		ch = getchar();
	}
	return s * f;
}

// Reads the tree (n vertices, n - 1 weighted undirected edges) and m queries.
// Runs the centroid decomposition, then prints "AYE" for each query whose
// path length exists and "NAY" otherwise.
int main(){
	memset(h, -1, sizeof(h));	// -1 terminates every (still empty) edge list
	n = read();
	m = read();
	for(int left = n - 1 ; left > 0 ; --left){
		int a = read();
		int b = read();
		int c = read();
		add(a, b, c);
		add(b, a, c);
	}
	for(int k = 1 ; k <= m ; ++k) q[k] = read();

	solve();

	for(int k = 1 ; k <= m ; ++k) fputs(tag[k] ? "AYE\n" : "NAY\n", stdout);
	return 0;
}
posted @ 2021-09-18 10:27  Tartarus_li  阅读(20)  评论(0编辑  收藏  举报