题目链接: https://www.codechef.com/problems/TRIPS

感觉CC有点毒瘤啊。。

题解: 首先有一个性质可能是因为太傻所以网上没人解释,然而我看了半天: 就是正序和倒序经过同一段路径,用时一样。

我原来想了个很麻烦的证法,ckw: "显然把一个序列划分成数量尽可能少的子串,每一段和不超过\(P\), 那么从左往右和从右往左都是最优解,所以他俩相等啊"

发现了这个性质以及其一些简单的推论,后面的就比较简单了

分块讨论

对于\(p>\sqrt n\)的询问,暴力倍增跳即可,最多跳\(\sqrt n\)次。

对于\(p\le \sqrt n\)的询问,按\(p\)从小到大排序,只有\(O(\sqrt n)\)个不同的\(p\), 对于每一个预处理数组\(f[i][j]\)表示\(i\)点往上跳\(2^j\)天(不是步)跳到哪里,询问暴力跳即可

时间复杂度\(O(n\sqrt n\log n)\)

然后尽管复杂度这么大,在CC上测出来时间是\(3.42s\) (时限\(8s\)).

代码

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cassert>
#include<algorithm>
#include<cmath>
using namespace std;

const int N = 1e5;
const int MXB = 317;
const int LGN = 16;
struct Edge
{
	int v,w,nxt;
} e[(N<<1)+3];
struct Query
{
	int u,v,p,id;
	bool operator <(const Query &arg) const
	{
		return p<arg.p;
	}
} qr[N+3];
int fe[N+3];
int fa[N+3][LGN+3];
int dep[N+3],dis[N+3];
int f[N+3][LGN+3];
int ans[N+3];
int n,q,en,B;

void addedge(int u,int v,int w)
{
	en++; e[en].v = v; e[en].w = w;
	e[en].nxt = fe[u]; fe[u] = en;
}

void dfs(int u)
{\
	for(int i=1; i<=LGN; i++) fa[u][i] = fa[fa[u][i-1]][i-1];
	for(int i=fe[u]; i; i=e[i].nxt)
	{
		int v = e[i].v;
		if(v==fa[u][0]) continue;
		fa[v][0] = u;
		dep[v] = dep[u]+1; dis[v] = dis[u]+e[i].w;
		dfs(v);
	}
}

int LCA(int u,int v)
{
	if(dep[u]<dep[v]) {swap(u,v);}
	int dif = dep[u]-dep[v];
	for(int i=0; i<=LGN; i++) {if(dif&(1<<i)) u = fa[u][i];}
	if(u==v) return u;
	for(int i=LGN; i>=0; i--)
	{
		if(fa[u][i]!=fa[v][i]) {u = fa[u][i],v = fa[v][i];}
	}
	return fa[u][0];
}

int jump0(int u,int p)
{
	int tdis = dis[u];
	for(int i=LGN; i>=0; i--)
	{
		if(fa[u][i]!=0 && tdis-dis[fa[u][i]]<=p) {u = fa[u][i];}
	}
	return u;
}

int jump_large(int &u,int lca,int p)
{
	int ret = 0;
	while(dis[u]-dis[lca]>=p)
	{
		u = jump0(u,p);
		ret++;
	}
	return ret;
}

int solve_large(int u,int v,int p)
{
	int ret = 0; int lca = LCA(u,v);
	int ret1 = jump_large(u,lca,p),ret2 = jump_large(v,lca,p);
	ret = ret1+ret2;
//	printf("u=%d v=%d\n",u,v);
	if(u==lca && v==lca);
	else if(dis[u]+dis[v]-2*dis[lca]<=p) ret++;
	else ret+=2;
	return ret;
}

int jump_small(int &u,int lca,int p)
{
	int ret = 0;
	for(int i=LGN; i>=0; i--)
	{
		if(dis[f[u][i]]>dis[lca]) {u = f[u][i]; ret += (1<<i);} //> not >= 
	}
	return ret;
}

int solve_small(int u,int v,int p)
{
	int ret = 0; int lca = LCA(u,v);
	int ret1 = jump_small(u,lca,p),ret2 = jump_small(v,lca,p);
	ret = ret1+ret2;
//	printf("u=%d v=%d lca%d ret1=%d ret2=%d\n",u,v,lca,ret1,ret2);
	if(u==lca && v==lca);
	else if(dis[u]+dis[v]-2*dis[lca]<=p) ret++;
	else ret+=2;
	return ret;
}

int main()
{
	scanf("%d",&n); B = sqrt(n);
	for(int i=1; i<n; i++)
	{
		int x,y,z; scanf("%d%d%d",&x,&y,&z);
		addedge(x,y,z); addedge(y,x,z);
	}
	dep[1] = dis[1] = 1; dfs(1);
	scanf("%d",&q);
	for(int i=1; i<=q; i++)
	{
		int u,v,p; scanf("%d%d%d",&qr[i].u,&qr[i].v,&qr[i].p); qr[i].id = i;
	}
	sort(qr+1,qr+q+1);
	for(int i=1; i<=n; i++) f[i][0] = fa[i][0];
	int id = 0;
	for(int i=2; i<=B; i++)
	{
		for(int j=2; j<=n; j++)
		{
			if(dis[j]-dis[fa[f[j][0]][0]]<=i) {f[j][0] = fa[f[j][0]][0];}
		}
		for(int j=1; j<=LGN; j++)
		{
			for(int k=1; k<=n; k++) f[k][j] = f[f[k][j-1]][j-1];
		}
		while(id<q && qr[id+1].p==i)
		{
			id++;
			ans[qr[id].id] = solve_small(qr[id].u,qr[id].v,qr[id].p);
		}
	}
	for(id=id+1; id<=q; id++)
	{
		ans[qr[id].id] = solve_large(qr[id].u,qr[id].v,qr[id].p);
	}
	for(int i=1; i<=q; i++) printf("%d\n",ans[i]);
	return 0;
}