题解[P7283 Janjetina]

题目链接
题意:求树上有多少对点 \((x,y)\) 满足其路径上边权最大值 \(-\) 路径长度 \(\geq\) 给定的 \(k\)
\(n \leq 10^5\),边权最大值、 \(k\) \(\leq 10^6\)

\(Solution:\)

具体思路类似我在 CF293E Close Vertices 中的题解
考虑点分治,设当前分治重心为 \(rt\)\(rt\) 的子树中每个点 \(x\)\(rt\) 距离为 \(len[x]\) ,边权最大值为 \(mx[x]\)
\(mx[x]-len[x]\geq k\) 直接算入贡献。
而处理 \(rt\) 的各个子树间的贡献时,
假设 \(x\) 是在当前处理的 \(rt\) 子树中,查询在之前已处理过的子树中的 \(y\) 有多少满足条件,
\(max(mx[x],mx[y])-(len[x]+len[y]) \geq k\)
于是先在之前处理的子树信息对 \(mx\) 排序,便可二分出第一个大于等于 \(mx[x]\)\(mx\),设其位置为 \(pos\)
若对 \(mx[x]\) 排序,则 \(pos\) 可递增

\(mx[y]<mx[pos]\) : \(len[y]<=mx[x]-len[x]-k\),直接树状数组一遍插入,一遍查询

\(mx[y]>=mx[pos]\) : \(mx[y]-len[y]>=k+len[x]\),在之前存下 \(mx[y]-len[y]\)
两边乘 \(-1\) 变成 \(len[y]-mx[y]\leq -(k+len[x])\),查询时再加上一个大数即可。
由于是 \(mx[y]>=mx[pos]\) ,将 \(mx[x]\) 反过来使 \(pos\) 递减,再树状数组一边插入,一边查询。

因为题中 \((x,y)\)\((y,x)\) 算两次,最终答案 \(\times 2\) 即可
时间复杂度:\(O(nlog^2 n)\)
代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
//const int N=1e5+10;
const int N=2e5+10;
const int K=5e6+10;
int n,m,x,y,k,rt,nn,tot,v,cnt;
ll ans;
int to[N<<1],nextn[N<<1],h[N],w[N<<1];
#define lowbit(x) x&(-x)
#define max(a,b) a>b?a:b
struct BIT1{
	ll t[K*3];
	void update(int x,int v){
		x+=K;
		for(int i=x;i<=K<<1;i+=lowbit(i))t[i]+=v;
	}
	ll inquiry(int pos){
		pos+=K;
		ll res=0;
		for(int i=pos;i;i-=lowbit(i))res+=t[i];
		return res;
	}
}t1;
struct BIT2{
	ll t[K<<1];
	void update(int x,int v){
		x=K-x;
		for(int i=x;i<=K<<1;i+=lowbit(i))t[i]+=v;
	}
	ll inquiry(int pos){
		ll res=0;
		pos=K-pos;
		for(int i=pos;i;i-=lowbit(i))res+=t[i];
		return res;
	}
}t2;
struct stata{
	ll dis;
	ll maxn;
	ll cnt;
	bool operator <(const stata &x)const{
		return x.maxn>maxn;
	}
}tmp[N],q0[N],q[N];
void add(int x,int y,int v){
	cnt++;
	to[cnt]=y;
	nextn[cnt]=h[x];
	h[x]=cnt; 
	w[cnt]=v;
}
int size[N],mxsize[N];
bool b[N];
void findrt(int x,int anc){
	size[x]=1,mxsize[x]=0;
	for(int i=h[x];i;i=nextn[i]){
		int y=to[i];
		if(b[y]||y==anc)continue;
		findrt(y,x);
		size[x]+=size[y];
		mxsize[x]=max(mxsize[x],size[y]);
	}
	mxsize[x]=max(mxsize[x],nn-size[x]);
	if(mxsize[x]<mxsize[rt])rt=x;
}
void dfs(int x,int anc,ll dis_,ll mx_){
	tot++;
	tmp[tot]=(stata){dis_,mx_,mx_-dis_};
	if(mx_-dis_>=k)ans++;
	for(int i=h[x];i;i=nextn[i]){
		int y=to[i];
		if(b[y]||y==anc)continue;
		dfs(y,x,dis_+1,max(mx_,w[i]));
	}
}
void work(int x){
	int tot1=0;
	for(int i=h[x];i;i=nextn[i]){
		int y=to[i];
		if(b[y])continue;
		tot=0;
		dfs(y,x,1,w[i]);
		sort(tmp+1,tmp+tot+1);
		int last=0;
		for(int j=1;j<=tot;j++){
			int pos=lower_bound(q+1,q+tot1+1,tmp[j])-q;
			pos--;
			for(int l=last+1;l<=pos;l++)t1.update(q[l].dis,1);
			ans+=t1.inquiry(tmp[j].cnt-k);
			last=pos;
		}
		for(int l=1;l<=last;l++)t1.update(q[l].dis,-1);
		last=tot1+1;
		for(int j=tot;j>0;j--){
			int pos=lower_bound(q+1,q+tot1+1,tmp[j])-q;
			for(int l=last-1;l>=pos;l--)t2.update(q[l].cnt,1);
			ans+=t2.inquiry(k+tmp[j].dis);
			last=pos;
		}
		for(int l=tot1;l>=last;l--)t2.update(q[l].cnt,-1);
		merge(q+1,q+tot1+1,tmp+1,tmp+tot+1,q0+1);
		tot1+=tot;
		for(int j=1;j<=tot1;j++)q[j]=q0[j];
	}
}
void solve(int x){
	b[x]=1;
	work(x);
	for(int i=h[x];i;i=nextn[i]){
		int y=to[i];
		if(b[y])continue;
		rt=0,mxsize[0]=n;
		nn=size[y];
		findrt(y,x);
		solve(rt);
	}
}
int main(){
	scanf("%d%d",&n,&k);
	for(int i=1;i<n;i++){
		scanf("%d%d%d",&x,&y,&v);
		add(x,y,v);
		add(y,x,v);
	}
	rt=0,mxsize[0]=n;
	nn=n;
	findrt(1,0);
	solve(rt);
	printf("%lld",ans<<1);
}
posted @ 2021-04-04 10:54  Y_B_X  阅读(49)  评论(0编辑  收藏  举报