题解[CF293E Close Vertices]
题意:求树上满足 \(i\) 到 \(j\) 的路径长度不超过 \(kk\) ,且点权和不超过 \(k\) 的点对 \((i,j)\) 的个数
\(Solution\)
首先,这道题明显是点分治,且与这题十分相像,只不过多了路径长度这一维的限制
像那题一样,在点分治每层处理当前 \(root\) 时,考虑枚举每条从 \(root\) 出发的各个子树相互间产生的贡献,可避免容斥
若某子树中的点自身以满足长度 \(\leq kk\) ,且点权和 \(\leq k\) ,直接计入答案
然后对每个子树再将子树中的点到根的点权和排序,再统计与在它之前被遍历过的子树间的答案,便可写出一个朴素算法:
//以下中d为点到根的点权和,dd为点到根的路径长,q为目前统计到的d与dd
void work(int x){
cnt=0;
for(int i=h[x];i;i=nextn[i]){
int y=to[i];
if(b[y])continue;
tot=0,dfs(y,x,w[i],1);//处理出子树中的tmp[].d、tmp[].dd
sort(tmp+1,tmp+tot+1);//对.d排序
for(int j=1;j<=tot;j++){
for(int i=1;i<=cnt&&q[i].d<=k-tmp[j].d;i++)if(q[i].dd<=kk-tmp[j].dd)ans++;//暴力
if(tmp[j].d<=k&&tmp[j].dd<=kk)ans++;
}
merge(tmp+1,tmp+tot+1,q+1,q+cnt+1,q0+1);//合并两段有序数组,不用sort
cnt+=tot;
for(int i=1;i<=cnt;i++)q[i]=q0[i];//保证q中.d有序
}
}
时间复杂度明显不对,考虑优化代码中暴力的那一段。
然而我们观察到由于 \(q\) 中 \(d\) 的递增, \(k-tmp[j].d\) 是可以用 \(\operatorname{lowerbound}\) 求出在 \(q\) 中的位置,记为 \(pos\) 。
那么,剩下的便是统计 \(q[1···pos]\) 中 \(q[l].dd\leq kk-tmp[j].dd\) 的 \(l\) 的个数。
而后,我们又能发现基于 \(tmp[j].d\) 的递增, \(pos\) 应该是递减的。
于是,我们将 \(tmp\) 反过来求, \(pos\) 便可递增,于是在查询 \(\leq kk-tmp[j].dd\) 的个数时就可以用树状数组,一边插入,一边查询。
时间复杂度: \(O(n\log^2 n)\)
总代码:( 当前在洛谷 \(rank2\) )
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e5+10;
int n,m,nn,cnt,tot,kk,x,y,rt;
ll k,ww,ans;
int to[N],nextn[N],h[N];
int mxsize[N],size[N];
ll w[N];
struct dis{
ll d;
int dd;
bool operator <(const dis &x)const{return x.d>d;}
}tmp[N],q[N],q0[N];
int t[N];
bool b[N];
int lowbit(int x){return x&(-x);}
void update(int x,int v){
for(int i=x;i<=n;i+=lowbit(i))t[i]+=v;
}
ll inquiry(int x){
ll res=0;
for(int i=x;i;i-=lowbit(i))res+=t[i];
return res;
}
void add(int x,int y,ll ww){
cnt++;
to[cnt]=y;
nextn[cnt]=h[x];
h[x]=cnt;
w[cnt]=ww;
}
void findrt(int x,int anc){
mxsize[x]=0;
size[x]=1;
for(int i=h[x];i;i=nextn[i]){
int y=to[i];
if(y==anc||b[y])continue;
findrt(y,x);
size[x]+=size[y];
mxsize[x]=max(mxsize[x],size[y]);
}
mxsize[x]=max(mxsize[x],nn-size[x]);
if(mxsize[x]<mxsize[rt])rt=x;
}
void dfs(int x,int anc,ll dis1,int dis_){
tmp[++tot]=(dis){dis1,dis_};
if(dis1<=k&&dis_<=kk)ans++;//直接计入答案
for(int i=h[x];i;i=nextn[i]){
int y=to[i];
if(y==anc||b[y])continue;
dfs(y,x,dis1+w[i],dis_+1);
}
}
void work(int x){
cnt=0;
for(int i=h[x];i;i=nextn[i]){
int y=to[i];
if(b[y])continue;
tot=0;
dfs(y,x,w[i],1);
sort(tmp+1,tmp+tot+1);
int last=0;
for(int j=tot;j>=1;j--)if(tmp[j].d<k&&tmp[j].dd<kk){//剪枝
int pos=lower_bound(q+1,q+cnt+1,dis{k-tmp[j].d+1,0})-q;//应是小于等于k-tmp[j].d的最大位置
while(q[pos].d>k-tmp[j].d)pos--;
for(int l=last+1;l<=pos&&l<=cnt;l++)update(q[l].dd,1);
ans+=inquiry(kk-tmp[j].dd);
last=pos;
}
for(int l=1;l<=last&&l<=cnt;l++)update(q[l].dd,-1);
merge(tmp+1,tmp+tot+1,q+1,q+cnt+1,q0+1);
cnt+=tot;
for(int i=1;i<=cnt;i++)q[i]=q0[i];
}
}
void solve(int x){
b[x]=1;
work(x);
for(int i=h[x];i;i=nextn[i]){
int y=to[i];
if(b[y])continue;
rt=0,mxsize[0]=n;
nn=size[y];
findrt(y,x);
solve(rt);
}
}
main(){
scanf("%d%d%lld",&n,&kk,&k);
for(int i=1;i<n;i++){
scanf("%d%lld",&y,&ww);
add(i+1,y,ww),add(y,i+1,ww);
}
rt=0,nn=n;
mxsize[0]=n;
findrt(1,0);
solve(rt);
printf("%lld",ans);
}