#扫描线,线段树#nssl 1459 空间复杂度
分析
由于\(k\leq 10\)所以考虑用总方案减去经过两个差的绝对值\(\leq k\)的点的路径数
分类讨论一下发现要处理祖先关系和其它关系两种情况,考虑怎么去重,可以将这些答案看作一个个矩形,
然后就是要求矩形的面积并,用扫描线+线段树解决
代码
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <vector>
#define rr register
using namespace std;
const int N=300011; typedef long long lll;
struct node{int y,next;}e[N<<1];
struct rec{
int x,l,r,type;
bool operator <(const rec &t)const{
return x<t.x;
}
}q[N*40];
int dfn[N],nfd[N],Tot,n,m; vector<int>K[N];
int ifn[N],as[N],tot,TOT,k,w[N<<2],lazy[N<<2]; lll ans;
inline signed iut(){
rr int ans=0; rr char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
inline void dfs(int x,int fa){
dfn[x]=++tot,nfd[tot]=x;
for (rr int i=as[x];i;i=e[i].next) if (e[i].y!=fa)
dfs(e[i].y,x),K[x].push_back(dfn[e[i].y]);
ifn[x]=tot;
}
inline void update(int k,int l,int r,int x,int y,int z){
if (l==x&&r==y){
lazy[k]+=z;
if (lazy[k]) w[k]=r-l+1;
else if (l==r) w[k]=0;
else w[k]=w[k<<1]+w[k<<1|1];
return;
}
rr int mid=(l+r)>>1;
if (y<=mid) update(k<<1,l,mid,x,y,z);
else if (x>mid) update(k<<1|1,mid+1,r,x,y,z);
else update(k<<1,l,mid,x,mid,z),update(k<<1|1,mid+1,r,mid+1,y,z);
if (lazy[k]) w[k]=r-l+1;
else w[k]=w[k<<1]+w[k<<1|1];
}
inline void add(int lx,int rx,int ly,int ry){
q[++Tot]={lx,ly,ry,1},q[++Tot]={rx+1,ly,ry,-1};
}
signed main(){
n=iut(),m=iut(),ans=1ll*n*(n+1)>>1;
for (rr int i=1;i<n;++i){
rr int x=iut(),y=iut();
e[++k]=(node){y,as[x]},as[x]=k;
e[++k]=(node){x,as[y]},as[y]=k;
}
dfs(1,0);
for (rr int i=1;i<=n;++i) sort(K[i].begin(),K[i].end());
for (rr int i=1;i<=n;++i)
for (rr int j=i+1;j<=i+m;++j){
if (j>n) break; rr int x=i,y=j;
if (dfn[x]>dfn[y]) x^=y,y^=x,x^=y;
if (dfn[x]<=dfn[y]&&dfn[y]<=ifn[x]){
int now=nfd[*--upper_bound(K[x].begin(),K[x].end(),dfn[y])];
if (dfn[now]>1) add(1,dfn[now]-1,dfn[y],ifn[y]);
if (ifn[now]<n) add(dfn[y],ifn[y],ifn[now]+1,n);
}else add(dfn[x],ifn[x],dfn[y],ifn[y]);
}
sort(q+1,q+1+Tot);
for (rr int i=1;i<Tot;++i)
update(1,1,n,q[i].l,q[i].r,q[i].type),ans-=1ll*(q[i+1].x-q[i].x)*w[1];
return !printf("%lld",ans);
}