题解[P7283 Janjetina]
题目链接
题意:求树上有多少对点 \((x,y)\) 满足其路径上边权最大值 \(-\) 路径长度 \(\geq\) 给定的 \(k\)。
\(n \leq 10^5\),边权最大值、 \(k\) \(\leq 10^6\)。
\(Solution:\)
具体思路类似我在 CF293E Close Vertices 中的题解
考虑点分治,设当前分治重心为 \(rt\),\(rt\) 的子树中每个点 \(x\) 到 \(rt\) 距离为 \(len[x]\) ,边权最大值为 \(mx[x]\)。
若 \(mx[x]-len[x]\geq k\) 直接算入贡献。
而处理 \(rt\) 的各个子树间的贡献时,
假设 \(x\) 是在当前处理的 \(rt\) 子树中,查询在之前已处理过的子树中的 \(y\) 有多少满足条件,
则 \(max(mx[x],mx[y])-(len[x]+len[y]) \geq k\)
于是先在之前处理的子树信息对 \(mx\) 排序,便可二分出第一个大于等于 \(mx[x]\) 的 \(mx\),设其位置为 \(pos\)
若对 \(mx[x]\) 排序,则 \(pos\) 可递增
若 \(mx[y]<mx[pos]\) : \(len[y]<=mx[x]-len[x]-k\),直接树状数组一遍插入,一遍查询
若 \(mx[y]>=mx[pos]\) : \(mx[y]-len[y]>=k+len[x]\),在之前存下 \(mx[y]-len[y]\),
两边乘 \(-1\) 变成 \(len[y]-mx[y]\leq -(k+len[x])\),查询时再加上一个大数即可。
由于是 \(mx[y]>=mx[pos]\) ,将 \(mx[x]\) 反过来使 \(pos\) 递减,再树状数组一边插入,一边查询。
因为题中 \((x,y)\) 与 \((y,x)\) 算两次,最终答案 \(\times 2\) 即可
时间复杂度:\(O(nlog^2 n)\)
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
//const int N=1e5+10;
const int N=2e5+10;
const int K=5e6+10;
int n,m,x,y,k,rt,nn,tot,v,cnt;
ll ans;
int to[N<<1],nextn[N<<1],h[N],w[N<<1];
#define lowbit(x) x&(-x)
#define max(a,b) a>b?a:b
struct BIT1{
ll t[K*3];
void update(int x,int v){
x+=K;
for(int i=x;i<=K<<1;i+=lowbit(i))t[i]+=v;
}
ll inquiry(int pos){
pos+=K;
ll res=0;
for(int i=pos;i;i-=lowbit(i))res+=t[i];
return res;
}
}t1;
struct BIT2{
ll t[K<<1];
void update(int x,int v){
x=K-x;
for(int i=x;i<=K<<1;i+=lowbit(i))t[i]+=v;
}
ll inquiry(int pos){
ll res=0;
pos=K-pos;
for(int i=pos;i;i-=lowbit(i))res+=t[i];
return res;
}
}t2;
struct stata{
ll dis;
ll maxn;
ll cnt;
bool operator <(const stata &x)const{
return x.maxn>maxn;
}
}tmp[N],q0[N],q[N];
void add(int x,int y,int v){
cnt++;
to[cnt]=y;
nextn[cnt]=h[x];
h[x]=cnt;
w[cnt]=v;
}
int size[N],mxsize[N];
bool b[N];
void findrt(int x,int anc){
size[x]=1,mxsize[x]=0;
for(int i=h[x];i;i=nextn[i]){
int y=to[i];
if(b[y]||y==anc)continue;
findrt(y,x);
size[x]+=size[y];
mxsize[x]=max(mxsize[x],size[y]);
}
mxsize[x]=max(mxsize[x],nn-size[x]);
if(mxsize[x]<mxsize[rt])rt=x;
}
void dfs(int x,int anc,ll dis_,ll mx_){
tot++;
tmp[tot]=(stata){dis_,mx_,mx_-dis_};
if(mx_-dis_>=k)ans++;
for(int i=h[x];i;i=nextn[i]){
int y=to[i];
if(b[y]||y==anc)continue;
dfs(y,x,dis_+1,max(mx_,w[i]));
}
}
void work(int x){
int tot1=0;
for(int i=h[x];i;i=nextn[i]){
int y=to[i];
if(b[y])continue;
tot=0;
dfs(y,x,1,w[i]);
sort(tmp+1,tmp+tot+1);
int last=0;
for(int j=1;j<=tot;j++){
int pos=lower_bound(q+1,q+tot1+1,tmp[j])-q;
pos--;
for(int l=last+1;l<=pos;l++)t1.update(q[l].dis,1);
ans+=t1.inquiry(tmp[j].cnt-k);
last=pos;
}
for(int l=1;l<=last;l++)t1.update(q[l].dis,-1);
last=tot1+1;
for(int j=tot;j>0;j--){
int pos=lower_bound(q+1,q+tot1+1,tmp[j])-q;
for(int l=last-1;l>=pos;l--)t2.update(q[l].cnt,1);
ans+=t2.inquiry(k+tmp[j].dis);
last=pos;
}
for(int l=tot1;l>=last;l--)t2.update(q[l].cnt,-1);
merge(q+1,q+tot1+1,tmp+1,tmp+tot+1,q0+1);
tot1+=tot;
for(int j=1;j<=tot1;j++)q[j]=q0[j];
}
}
void solve(int x){
b[x]=1;
work(x);
for(int i=h[x];i;i=nextn[i]){
int y=to[i];
if(b[y])continue;
rt=0,mxsize[0]=n;
nn=size[y];
findrt(y,x);
solve(rt);
}
}
int main(){
scanf("%d%d",&n,&k);
for(int i=1;i<n;i++){
scanf("%d%d%d",&x,&y,&v);
add(x,y,v);
add(y,x,v);
}
rt=0,mxsize[0]=n;
nn=n;
findrt(1,0);
solve(rt);
printf("%lld",ans<<1);
}