题解 第零题

传送门

6:40看出一个树剖+链上倍增的做法,然后在有拍的情况下从7:40写到9:40勉强写完来不及调了就离谱

首先肯定会有一个错误的树剖思路:链上求和再除以 \(k\)
考虑如何正确地做这个树剖
那就要支持查询一段 \([L, R]\)\(L\) 时,\(k'\) 为给定值,到 \(R\) 会复活多少次及此时剩余的 \(k\)
而且注意到在第一次复活后就和初始 \(k'\) 无关了
所以可以在链上倍增出体力为 \(k'\) 时的下一个死亡点,然后再倍增跳死亡点
如果发现再跳一次死亡点就跳出当前链的范围了,就区间求和算出剩余体力
需要证个小结论:
对于一条链,从下往上走的死亡次数和从上往下走的死亡次数是一样的
考场上口胡的比较幼稚的证明(实质上好像是调整法):
如果把每个点看成一口井,\(k\) 的体力看成一段长为 \(k\) 的绳子
那我们从最下面一口井开始向上面能够到的最远的井扯绳子,多余的绳头垂到能够到的最远的井里
所需绳子个数为从下往上走的死亡次数
至于从上往下走,拉动每条绳子使上方的绳头在井口,下方的绳头垂到井里,即为一组合法方案
显然绳子条数不变,得证

然后除了树剖+链上倍增有亿点点难写之外就没什么了,复杂度 \(O(nlog^2n)\) 不过常数很小

Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 200010
#define ll long long
#define reg register int
#define fir first
#define sec second
#define make make_pair
//#define int long long 

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, k, q;
int head[N], size;
struct edge{int to, next, val;}e[N<<1];
inline void add(int s, int t, int w) {e[++size].to=t; e[size].val=w; e[size].next=head[s]; head[s]=size;}

namespace force{
	int dfs(int u, int to, int fa, int rest, int cnt) {
		if (rest<=0) rest=k, ++cnt;
		if (u==to) return cnt;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v=e[i].to;
			if (v==fa) continue;
			int tem=dfs(v, to, u, rest-e[i].val, cnt);
			if (~tem) return tem;
		}
		return -1;
	}
	void solve() {
		q=read();
		for (int i=1,s,t; i<=q; ++i) {
			s=read(); t=read();
			printf("%d\n", dfs(s, t, 0, k, 0));
		}
		exit(0);
	}
}

namespace task1{
	int id[N], rk[N], val[N], tot, nxt[N], st[N][30], scnt[N][30], dep[N], lg[N];
	int tl[N<<2], tr[N<<2]; ll dat[N<<2], sum[N];
	#define tl(p) tl[p]
	#define tr(p) tr[p]
	#define dat(p) dat[p]
	#define pushup(p) dat(p)=dat(p<<1)+dat(p<<1|1)
	void build(int p, int l, int r) {
		tl(p)=l; tr(p)=r;
		if (l==r) {dat(p)=val[rk[l]]; return ;}
		int mid=(l+r)>>1;
		build(p<<1, l, mid);
		build(p<<1|1, mid+1, r);
		pushup(p);
	}
	ll qsum(int p, int l, int r) {
		if (l<=tl(p) && r>=tr(p)) return dat(p);
		int mid=(tl(p)+tr(p))>>1; ll ans=0;
		if (l<=mid) ans+=qsum(p<<1, l, r);
		if (r>mid) ans+=qsum(p<<1|1, l, r);
		return ans;
	}
	void dfs(int u, int fa, int in_val) {
		id[u]=++tot;
		rk[tot]=u;
		val[u]=in_val;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v=e[i].to;
			if (v==fa) continue;
			dfs(v, u, e[i].val);
		}
	}
	int qpos(int ql, int r, ll rest) {
		//cout<<"qpos "<<ql<<' '<<r<<' '<<rest<<endl;
		int l=ql, mid;
		while (l<=r) {
			mid=(l+r)>>1;
			//cout<<"mid: "<<l<<' '<<r<<' '<<mid<<endl;
			//cout<<"cmp: "<<sum[mid]-sum[ql]<<' '<<rest<<endl;
			if (sum[mid]-sum[ql]<rest) l=mid+1;
			else r=mid-1;
		}
		//cout<<endl;
		return l;
	}
	void solve() {
		dfs(1, 0, 0);
		build(1, 1, n);
		for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
		for (int i=1; i<=n; ++i) sum[i]=sum[i-1]+val[i];
		for (int i=n; i; --i) {
			nxt[i]=qpos(i, n, k);
			if (nxt[i]>=n+1) dep[i]=1;
			else dep[i]=dep[nxt[i]]+1, st[i][0]=nxt[i]; //, scnt[i][0]=1;
		}
		for (int i=n; i; --i) 
			for (int j=1; j<30; ++j)
				if (dep[i]>=1<<j) st[i][j]=st[st[i][j-1]][j-1];
				else break;
		//cout<<"sum: "; for (int i=1; i<=n; ++i) cout<<sum[i]<<' '; cout<<endl;
		//cout<<"nxt: "; for (int i=1; i<=n; ++i) cout<<nxt[i]<<' '; cout<<endl;
		//cout<<"dep: "; for (int i=1; i<=n; ++i) cout<<dep[i]<<' '; cout<<endl;
		
		q=read();
		for (int i=1,s,t,ans; i<=q; ++i) {
			s=read(); t=read(); ans=1;
			if (s>t) swap(s, t);
			if (qsum(1, s+1, t)<k) {puts("0"); continue;}
			
			s=qpos(s, t, k);
			//cout<<"s: "<<s<<endl;
			for (int j=lg[dep[s]-1]-1; ~j; --j) if (st[s][j]) {
				//cout<<"jump: "<<st[s][j]<<endl;
				if (st[s][j]<=t) s=st[s][j], ans+=1<<j;
			}
			//cout<<"s2: "<<s<<endl;
			
			printf("%d\n", ans);
		}
		exit(0);
	}
}

namespace task2{
	int id[N], rk[N], val[N], tot, nxt[N], st[N][30], scnt[N][30], dep[N], lg[N], top[N], siz[N], msiz[N], mson[N], tdep[N], fa[N], msonv[N];
	int sta[N], stop;
	int tl[N<<2], tr[N<<2]; ll dat[N<<2], sum[N];
	#define tl(p) tl[p]
	#define tr(p) tr[p]
	#define dat(p) dat[p]
	#define pushup(p) dat(p)=dat(p<<1)+dat(p<<1|1)
	void build(int p, int l, int r) {
		tl(p)=l; tr(p)=r;
		if (l==r) {dat(p)=val[l]; return ;}
		int mid=(l+r)>>1;
		build(p<<1, l, mid);
		build(p<<1|1, mid+1, r);
		pushup(p);
	}
	ll qsum(int p, int l, int r) {
		if (l<=tl(p) && r>=tr(p)) return dat(p);
		int mid=(tl(p)+tr(p))>>1; ll ans=0;
		if (l<=mid) ans+=qsum(p<<1, l, r);
		if (r>mid) ans+=qsum(p<<1|1, l, r);
		return ans;
	}
	void dfs1(int u, int pa) {
		siz[u]=1;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v=e[i].to;
			if (v==pa) continue;
			fa[v]=u, tdep[v]=tdep[u]+1, dfs1(v, u);
			siz[u]+=siz[v];
			if (siz[v]>msiz[u]) msiz[u]=siz[v], mson[u]=v, msonv[u]=e[i].val;
		}
	}
	void dfs2(int u, int f, int t, int in_val) {
		if (u==t) sta[++stop]=u;
		top[u]=t;
		id[u]=++tot;
		rk[tot]=u;
		val[id[u]]=in_val;
		if (!mson[u]) return ;
		dfs2(mson[u], u, t, msonv[u]);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v=e[i].to;
			if (v==f || v==mson[u]) continue;
			dfs2(v, u, v, e[i].val);
		}
	}
	int qpos(int l, int qr, ll rest) {
		//cout<<"qpos "<<ql<<' '<<r<<' '<<rest<<endl;
		int r=qr, mid;
		while (l<=r) {
			mid=(l+r)>>1;
			//cout<<"mid: "<<l<<' '<<r<<' '<<mid<<endl;
			//cout<<"cmp: "<<sum[mid]-sum[ql]<<' '<<rest<<endl;
			if (sum[qr]-sum[mid-1]>=rest) l=mid+1;
			else r=mid-1;
		}
		//cout<<endl;
		return l-2;
	}
	int lca(int a, int b) {
		while (top[a]!=top[b]) {
			if (tdep[top[a]]<tdep[top[b]]) swap(a, b);
			a=fa[top[a]];
		}
		if (tdep[a]>tdep[b]) swap(a, b);
		return a;
	}
	pair<int, int> query(int a, int g) {
		//cout<<"query "<<a<<' '<<g<<endl;
		ll ks=k; int ans=0;
		while (top[a]!=top[g]) {
			int t=qpos(0, id[a], ks);
			//cout<<"t: "<<t<<' '<<id[top[a]]<<endl;
			if (t>=id[top[a]]) {
				a=rk[t], ks=k, ++ans;
				//cout<<"pos4 "<<a<<' '<<st[a][0]<<endl;
				//cout<<"top: "<<top[a]<<endl;
				for (int i=lg[dep[id[a]]]-1; ~i; --i)
					if (st[id[a]][i]>=id[top[a]]) a=rk[st[id[a]][i]], ans+=1<<i; //cout<<"add: "<<st[id[a]][i]<<' '<<(1<<i)<<endl, 
			}
			if (a!=top[a]) {
				ks-=qsum(1, id[top[a]]+1, id[a]); //, cout<<"-= "<<id[top[a]]<<' '<<id[a]<<' '<<qsum(1, id[top[a]], id[a])<<endl;
				if (ks<=0) ks=k, ++ans;
			}
			ks-=val[id[top[a]]];
			if (ks<=0) ks=k, ++ans;
			a=fa[top[a]];
			//cout<<"first jump: "<<a<<' '<<ks<<endl;
		}
		//cout<<"pos3: "<<a<<' '<<g<<' '<<ans<<endl;
		if (a!=g) {
			int t=qpos(0, id[a], ks);
			//cout<<"t: "<<rk[t]<<endl;
			if (t>=id[g]) {
				//cout<<"pos5"<<endl;
				a=rk[t], ks=k, ++ans;
				for (int i=lg[dep[id[a]]]-1; ~i; --i)
					if (st[id[a]][i]>=id[g]) a=rk[st[id[a]][i]], ans+=1<<i;
			}
			ks-=qsum(1, id[g]+1, id[a]);
			if (ks<=0) ks=k, ++ans;
		}
		//cout<<"return"<<endl<<endl;
		return make(ans, ks);
	}
	void solve() {
		tdep[1]=1; dfs1(1, 0); dfs2(1, 0, 1, 0);
		build(1, 1, n);
		for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
		for (int i=1; i<=n; ++i) sum[i]=sum[i-1]+val[i];
		for (int i=1; i<=n; ++i) {
			nxt[i]=qpos(1, i, k);
			if (nxt[i]<=0) dep[i]=1;
			else dep[i]=dep[nxt[i]]+1, st[i][0]=nxt[i]; //, scnt[i][0]=1;
		}
		for (int i=1; i<=n; ++i) 
			for (int j=1; j<30; ++j)
				if (dep[i]>=1<<j) st[i][j]=st[st[i][j-1]][j-1];
				else break;
		#if 0
		cout<<"val: "; for (int i=1; i<=n; ++i) cout<<val[i]<<' '; cout<<endl;
		cout<<"sum: "; for (int i=1; i<=n; ++i) cout<<sum[i]<<' '; cout<<endl;
		cout<<"nxt: "; for (int i=1; i<=n; ++i) cout<<nxt[i]<<' '; cout<<endl;
		cout<<"dep: "; for (int i=1; i<=n; ++i) cout<<dep[i]<<' '; cout<<endl;
		cout<<"id: "; for (int i=1; i<=n; ++i) cout<<id[i]<<' '; cout<<endl;
		cout<<"top: "; for (int i=1; i<=n; ++i) cout<<top[i]<<' '; cout<<endl;
		#endif
		
		q=read();
		pair<int, int> t1, t2;
		for (int i=1,s,t,g; i<=q; ++i) {
			s=read(); t=read(); g=lca(s, t);
			//cout<<"lca: "<<lca(s, t)<<endl;
			if (s==g || t==g) {
				//cout<<"pos1"<<endl;
				if (t!=g) swap(s, t);
				printf("%d\n", query(s, t).fir);
			}
			else {
				//cout<<"pos2"<<endl;
				t1=query(s, g), t2=query(t, g);
				//cout<<"get_ans: "<<t1.fir<<','<<t1.sec<<' '<<t2.fir<<','<<t2.sec<<endl;
				printf("%d\n", t1.fir+t2.fir+(t1.sec+t2.sec<=k));
			}
			//cout<<endl;
		}
		exit(0);
	}
}

signed main()
{
	memset(head, -1, sizeof(head));
	n=read(); k=read();
	for (int i=1,u,v,w; i<n; ++i) {
		u=read(); v=read(); w=read();
		add(u, v, w); add(v, u, w);
	}
	//force::solve();
	task2::solve();
	
	return 0;
}
posted @ 2021-09-11 06:34  Administrator-09  阅读(5)  评论(0编辑  收藏  举报