P7283 [COCI2020-2021#4] Janjetina
显然点分治。
考虑设 \(Mx_i\) 为重心到 \(i\) 的链的最大值,\(L_i\) 为路径长度,那么条件变为:
\[\max(Mx_x, Mx_y)-L_x - k \ge L_y
\]
我们考虑直接枚举当前的 \(\max(Mx_x, Mx_y)=Mx_x\),那么需要保证之前加入的点的 \(Mx_y\) 要小于等于 \(Mx_x\),直接以 \(Mx\) 为关键字 sort 一遍,然后直接树状数组爆算。
但是这样显然会把同一个子树内的假点对给统计到,多做一遍减去即可。
复杂度 \(O(n\log^2 n)\)
#include <stdio.h>
#include <algorithm>
#include <string.h>
#include <cctype>
#include <vector>
using namespace std;
typedef long long ll;
typedef pair <int, int> Pii;
const int INF=0x3f3f3f3f;
const int mo=1e9+7;
inline int read(){
char ch=getchar();int x=0, f=1;
while(!isdigit(ch)){if(ch=='-') f=-1; ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return x*f;
}
inline void write(int x){
if(x<0) putchar('-'), x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
inline int ksm(int a, int b){
int ret=1;
for(; b; b>>=1, a=1ll*a*a%mo)
if(b&1) ret=1ll*ret*a%mo;
return ret;
}
const int N=1e5+5;
const int M=1e5;
int idc, vis[N], siz[N], Mx[N], dep[N], mn, G, tot, h[N];
struct Edge{int to, nxt, w;}d[N*2];
void add(int x, int y, int w){d[++tot]=(Edge){y, h[x], w};h[x]=tot;}
#define lowbit(x) (x&(-x))
int n, bit[N*2], K;
void add(int x, int v){
++x;
for(; x<=M; x+=lowbit(x)) bit[x]+=v;
}
int query(int x){
if((++x)<=0) return 0;
int res=0;
for(; x; x-=lowbit(x)) res+=bit[x];
return res;
}
#undef lowbit
int sta[N], top;
bool cmp(int x, int y){
return Mx[x]<Mx[y];
}
void Getg(int x, int up){
vis[x]=idc, siz[x]=1;int mx=0;
for(int i=h[x], v; i; i=d[i].nxt)
if(vis[v=d[i].to]<idc&&vis[v]!=-1)
Getg(v, up), siz[x]+=siz[v], mx=max(mx, siz[v]);
mx=max(mx, up-siz[x]);
if(mx<mn) mn=mx, G=x;
}
long long ans=0;
void solve(int L, int R, int flg){
int ls=ans;
sort(sta+L, sta+R+1, cmp);
for(int i=L; i<=R; ++i)
ans+=flg*query(Mx[sta[i]]-dep[sta[i]]-K),
add(dep[sta[i]], 1);
for(int i=L; i<=R; ++i) add(dep[sta[i]], -1);
}
void calc(int x, int rt){
vis[x]=idc;sta[++top]=x, siz[x]=1;
for(int i=h[x], v; i; i=d[i].nxt)
if(vis[v=d[i].to]<idc&&vis[v]!=-1){
int lst=top;
Mx[v]=max(Mx[x], d[i].w), dep[v]=dep[x]+1,
calc(v, 0), siz[x]+=siz[v];
if(x!=rt) continue;
solve(lst+1, top, -1);
}
if(x==rt) solve(1, top, 1);
}
void dfz(int rt){
// printf("---%d\n", rt);
++idc, dep[rt]=0, Mx[rt]=0, top=0;
calc(rt, rt), vis[rt]=-1;
for(int i=h[rt], v; i; i=d[i].nxt)
if(vis[v=d[i].to]!=-1)
++idc, G=v, mn=siz[v],
Getg(v, siz[v]), dfz(G);
}
signed main(){
n=read(), K=read();
for(int i=1, x, y, z; i<n; ++i)
x=read(), y=read(), z=read(),
add(x, y, z), add(y, x, z);
++idc, G=1, mn=n, Getg(1, n), dfz(G);
ans=2ll*ans;
printf("%lld", ans);
return 0;
}