CF161D Distance in Tree(点分治)
这是一道板子题
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// ---- Edge storage: linked adjacency lists, edge ids start at 1, 0 ends a list ----
int u[50010*2];      // u[i]: vertex the i-th stored edge leaves from
int v[50010*2];      // v[i]: vertex the i-th stored edge points to
int nxt[50010*2];    // nxt[i]: next edge id in the same vertex's list
int fir[50010];      // fir[x]: id of the most recently added edge leaving x
int cnt;             // number of directed edges stored so far

// ---- Root selection (filled in by getroot) ----
int root;            // best root candidate found so far; 0 means "none yet"
int sz[50010];       // sz[x]: subtree size of x in the latest getroot traversal
int f[50010];        // f[x]: largest piece left when x is removed; f[0] is a sentinel set in main
int Siz;             // size of the component getroot is currently examining

// ---- Distance buffer (filled in by getdis, consumed by solve) ----
int middis[50010];   // middis[1..midcnt]: collected distances
int midcnt;          // number of valid entries in middis

// ---- Problem input and bookkeeping ----
int n;               // number of vertices, read in main
int k;               // required path length, read in main
int vis[50010];      // vis[x] != 0 once x has been used as a division root

long long ans=0;     // running total of vertex pairs at distance exactly k
// Stores the directed edge ui -> vi and pushes it onto the front of ui's
// adjacency list. Callers add both directions to model an undirected edge.
void addedge(int ui,int vi){
    const int id=++cnt;   // ids start at 1, so 0 can terminate a list
    u[id]=ui;
    v[id]=vi;
    nxt[id]=fir[ui];      // old head becomes the successor of the new edge
    fir[ui]=id;
}
// Walks the not-yet-removed component containing u (arriving from fa) and
// computes, for every vertex x reached:
//   sz[x] - size of x's subtree in this traversal,
//   f[x]  - the largest piece left if x were removed, where the piece on the
//           fa side is taken as Siz - sz[x].
// Whenever a vertex has a strictly smaller f than the current `root`, it
// becomes the new `root`. The caller must set Siz and reset `root` first.
void getroot(int u,int fa){
    sz[u]=1;
    f[u]=1;
    for(int e=fir[u];e!=0;e=nxt[e]){
        const int to=v[e];
        if(to!=fa&&!vis[to]){
            getroot(to,u);
            sz[u]+=sz[to];
            if(sz[to]>f[u])
                f[u]=sz[to];
        }
    }
    const int upper=Siz-sz[u];   // everything outside u's subtree
    if(upper>f[u])
        f[u]=upper;
    if(f[u]<f[root])
        root=u;
}
// Pre-order walk starting at u (arriving from fa) that appends the distance
// of every reachable, not-yet-removed vertex to middis[]. d is the distance
// assigned to u itself; each step away adds 1.
void getdis(int u,int d,int fa){
    midcnt+=1;
    middis[midcnt]=d;
    for(int e=fir[u];e!=0;e=nxt[e]){
        const int to=v[e];
        if(to!=fa&&!vis[to])
            getdis(to,d+1,u);
    }
}
// Binary search over middis[l..midcnt] (expected to be sorted ascending).
// Returns the smallest index in that range whose value is >= k,
// or 0 when no such index exists (including when the range is empty).
int look1(int l,int k){
    int lo=l;
    int hi=midcnt;
    int found=0;
    while(lo<=hi){
        const int m=lo+(hi-lo)/2;
        if(middis[m]>=k){
            found=m;      // candidate; keep looking further left
            hi=m-1;
        }else{
            lo=m+1;
        }
    }
    return found;
}
// Binary search over middis[l..midcnt] (expected to be sorted ascending).
// Returns the largest index in that range whose value is <= k,
// or 0 when no such index exists (including when the range is empty).
int look2(int l,int k){
    int lo=l;
    int hi=midcnt;
    int found=0;
    while(lo<=hi){
        const int m=lo+(hi-lo)/2;
        if(middis[m]>k){
            hi=m-1;
        }else{
            found=m;      // candidate; keep looking further right
            lo=m+1;
        }
    }
    return found;
}
int solve(void){
sort(middis+1,middis+midcnt+1);
// for(int i=1;i<=midcnt;i++)
// printf("%d ",middis[i]);
// getchar();
// printf("\n");
int mid=0;
int l=1;
while(l<midcnt&&middis[l]+middis[midcnt]<k)
++l;
while(l<midcnt&&k-middis[l]>=middis[l]){
int l2=look2(l+1,k-middis[l]),l1=look1(l+1,k-middis[l]);
if(l2>=l1)
mid+=l2-l1+1;
l++;
}
return mid;
}
// Processes the component whose chosen root is u, then recurses into the
// pieces left after removing u.
//
// Counting: solve() over the distances from u counts every pair (x, y) in the
// component with dist(u,x) + dist(u,y) == k. That includes pairs lying in the
// same neighbouring piece, whose real path does not pass through u. Those are
// removed by running solve() on each piece separately, with distances
// starting at 1, and subtracting the result.
//
// Precondition: Siz holds the size of u's component and sz[] holds the values
// from the getroot() traversal that selected u.
void divide(int u){
    // Remember the component size now: Siz is overwritten for each piece below.
    const int total=Siz;

    vis[u]=true;
    midcnt=0;
    getdis(u,0,0);
    ans+=solve();

    for(int e=fir[u];e;e=nxt[e]){
        const int to=v[e];
        if(vis[to])
            continue;

        // Remove pairs that stay entirely inside this piece.
        midcnt=0;
        getdis(to,1,0);
        ans-=solve();

        // Size of the piece containing `to`. sz[] is relative to where the
        // previous getroot() traversal started: if `to` was u's parent in that
        // traversal, sz[to] still includes u's whole subtree, so the piece is
        // everything outside u's subtree instead. Using sz[to] there would
        // overstate the size and the root picked for the piece would not
        // necessarily be its centroid.
        Siz=(sz[to]>sz[u])?total-sz[u]:sz[to];

        root=0;
        getroot(to,u);
        divide(root);
    }
}
// Reads n and k followed by n-1 undirected edges, runs the decomposition
// from the root selected for the whole tree, and prints the number of
// vertex pairs at distance exactly k.
int main(){
    scanf("%d %d",&n,&k);

    int remaining=n-1;
    while(remaining-->0){
        int a,b;
        scanf("%d %d",&a,&b);
        // Store both directions of the undirected edge.
        addedge(a,b);
        addedge(b,a);
    }

    f[0]=0x3f3f3f3f;   // sentinel: any real vertex beats root == 0
    Siz=n;             // the first component is the entire tree
    getroot(1,0);
    divide(root);

    printf("%lld",ans);
    return 0;
}