AcWing 252 树 (点分治)
题目链接:https://www.acwing.com/problem/content/description/254/
每次找到树的重心,分治下去统计答案(经过当前根节点的路径)即可
统计答案使用了指针扫描数组的方法,要注意去掉同一子树内路径的答案
还可以直接在树上统计子树答案(这个方法的好处是保证了分开的两段路径不在同一子树内),但是要使用平衡树,代码复杂度高
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<stack>
#include<queue>
using namespace std;
typedef long long ll;
const int maxn = 200010;
const int INF = 1000000007;
int N, K;
int mx, tot, root, ans, tail;
struct Node{
int u, bel, dis;
bool operator < (const Node &x) const{
return dis < x.dis;
}
}a[maxn];
void Add_Node(int u, int bel, int dis){ a[++tail].u = u, a[tail].bel = bel, a[tail].dis = dis; }
int h[maxn], cnt;
struct{
int to, next;
int cost;
}e[maxn << 1];
void add(int u, int v, int w){
e[++cnt].to = v;
e[cnt].cost = w;
e[cnt].next = h[u];
h[u] = cnt;
}
int sz[maxn], son[maxn], dis[maxn], vis[maxn];
int c[maxn]; // c[i] :属于 i 子树的节点有多少个
void Get_Root(int u, int par){ // 找重心
sz[u] = 1, son[u] = 0;
for(int i = h[u]; i != -1; i = e[i].next){
int v = e[i].to;
if(vis[v] || v == par) continue;
Get_Root(v, u);
sz[u] += sz[v];
son[u] = max(son[u], sz[v]);
}
son[u] = max(son[u], tot - son[u]);
if(mx > son[u]){
root = u;
mx = son[u];
}
}
void Get_Dis(int u, int par, int rt){ // 统计距离
for(int i = h[u]; i != -1; i = e[i].next){
int v = e[i].to, w = e[i].cost;
if(vis[v] || v == par) continue;
dis[v] = dis[u] + w;
Add_Node(v, rt, dis[v]);
Get_Dis(v, u, rt);
}
}
void calc(int u){
dis[u] = 0; Add_Node(u, 0, 0);
for(int i = h[u] ; i != -1; i = e[i].next){
int v = e[i].to, w = e[i].cost;
if(vis[v]) continue;
dis[v] = dis[u] + w;
Add_Node(v, v, dis[v]);
Get_Dis(v, u, v);
}
sort(a + 1, a + 1 + tail); // 将子树节点按距离排序
for(int i = 1; i <= tail ; ++i ) ++c[a[i].bel]; // 统计每个子树内节点的数量
// 统计答案
int head;
for(head = 1; head < tail ; ++head){
--c[a[head].bel];
while(a[tail].dis + a[head].dis > K){
--c[a[tail].bel];
--tail;
if(tail == head) break; // 防止越界
}
ans += tail - head - c[a[head].bel];
}
for(int i = head; i <= tail ; ++i ) --c[a[i].bel];
}
void fenzhi(int u){
vis[u] = 1;
tail = 0;
calc(u);
for(int i = h[u]; i != -1; i = e[i].next){
int v = e[i].to , w = e[i].cost;
if(vis[v]) continue;
mx = INF, root = 0, tot = sz[v];
Get_Root(v,0);
fenzhi(root);
}
}
ll read(){ ll s=0,f=1; char ch=getchar(); while(ch<'0' || ch>'9'){ if(ch=='-') f=-1; ch=getchar(); } while(ch>='0' && ch<='9'){ s=s*10+ch-'0'; ch=getchar(); } return s*f; }
int main(){
while(1){
memset(vis, 0, sizeof(vis));
memset(h, -1, sizeof(h)); cnt = 0; ans = 0;
N = read(), K = read();
if(!N && !K) break;
int u, v, w;
for(int i = 1; i < N; ++i){
u = read(), v = read(), w = read();
++u, ++v;
add(u, v, w), add(v, u, w);
}
mx = INF, tot = N;
Get_Root(1, 0);
fenzhi(root);
printf("%d\n",ans);
}
return 0;
}