题目链接: https://www.codechef.com/problems/TRIPS
感觉CC有点毒瘤啊。。
题解: 首先有一个性质可能是因为太傻所以网上没人解释,然而我看了半天: 就是正序和倒序经过同一段路径,用时一样。
我原来想了个很麻烦的证法,ckw: "显然把一个序列划分成数量尽可能少的子串,每一段和不超过\(P\), 那么从左往右和从右往左都是最优解,所以他俩相等啊"
发现了这个性质以及其一些简单的推论,后面的就比较简单了
分块讨论
对于\(p>\sqrt n\)的询问,暴力倍增跳即可,最多跳\(\sqrt n\)次。
对于\(p\le \sqrt n\)的询问,按\(p\)从小到大排序,只有\(O(\sqrt n)\)个不同的\(p\), 对于每一个预处理数组\(f[i][j]\)表示\(i\)点往上跳\(2^j\)天(不是步)跳到哪里,询问暴力跳即可
时间复杂度\(O(n\sqrt n\log n)\)
然后尽管复杂度这么大,在CC上测出来时间是\(3.42s\) (时限\(8s\)).
代码
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cassert>
#include<algorithm>
#include<cmath>
using namespace std;
const int N = 1e5;
const int MXB = 317;
const int LGN = 16;
struct Edge
{
int v,w,nxt;
} e[(N<<1)+3];
struct Query
{
int u,v,p,id;
bool operator <(const Query &arg) const
{
return p<arg.p;
}
} qr[N+3];
int fe[N+3];
int fa[N+3][LGN+3];
int dep[N+3],dis[N+3];
int f[N+3][LGN+3];
int ans[N+3];
int n,q,en,B;
void addedge(int u,int v,int w)
{
en++; e[en].v = v; e[en].w = w;
e[en].nxt = fe[u]; fe[u] = en;
}
void dfs(int u)
{\
for(int i=1; i<=LGN; i++) fa[u][i] = fa[fa[u][i-1]][i-1];
for(int i=fe[u]; i; i=e[i].nxt)
{
int v = e[i].v;
if(v==fa[u][0]) continue;
fa[v][0] = u;
dep[v] = dep[u]+1; dis[v] = dis[u]+e[i].w;
dfs(v);
}
}
int LCA(int u,int v)
{
if(dep[u]<dep[v]) {swap(u,v);}
int dif = dep[u]-dep[v];
for(int i=0; i<=LGN; i++) {if(dif&(1<<i)) u = fa[u][i];}
if(u==v) return u;
for(int i=LGN; i>=0; i--)
{
if(fa[u][i]!=fa[v][i]) {u = fa[u][i],v = fa[v][i];}
}
return fa[u][0];
}
int jump0(int u,int p)
{
int tdis = dis[u];
for(int i=LGN; i>=0; i--)
{
if(fa[u][i]!=0 && tdis-dis[fa[u][i]]<=p) {u = fa[u][i];}
}
return u;
}
int jump_large(int &u,int lca,int p)
{
int ret = 0;
while(dis[u]-dis[lca]>=p)
{
u = jump0(u,p);
ret++;
}
return ret;
}
int solve_large(int u,int v,int p)
{
int ret = 0; int lca = LCA(u,v);
int ret1 = jump_large(u,lca,p),ret2 = jump_large(v,lca,p);
ret = ret1+ret2;
// printf("u=%d v=%d\n",u,v);
if(u==lca && v==lca);
else if(dis[u]+dis[v]-2*dis[lca]<=p) ret++;
else ret+=2;
return ret;
}
int jump_small(int &u,int lca,int p)
{
int ret = 0;
for(int i=LGN; i>=0; i--)
{
if(dis[f[u][i]]>dis[lca]) {u = f[u][i]; ret += (1<<i);} //> not >=
}
return ret;
}
int solve_small(int u,int v,int p)
{
int ret = 0; int lca = LCA(u,v);
int ret1 = jump_small(u,lca,p),ret2 = jump_small(v,lca,p);
ret = ret1+ret2;
// printf("u=%d v=%d lca%d ret1=%d ret2=%d\n",u,v,lca,ret1,ret2);
if(u==lca && v==lca);
else if(dis[u]+dis[v]-2*dis[lca]<=p) ret++;
else ret+=2;
return ret;
}
int main()
{
scanf("%d",&n); B = sqrt(n);
for(int i=1; i<n; i++)
{
int x,y,z; scanf("%d%d%d",&x,&y,&z);
addedge(x,y,z); addedge(y,x,z);
}
dep[1] = dis[1] = 1; dfs(1);
scanf("%d",&q);
for(int i=1; i<=q; i++)
{
int u,v,p; scanf("%d%d%d",&qr[i].u,&qr[i].v,&qr[i].p); qr[i].id = i;
}
sort(qr+1,qr+q+1);
for(int i=1; i<=n; i++) f[i][0] = fa[i][0];
int id = 0;
for(int i=2; i<=B; i++)
{
for(int j=2; j<=n; j++)
{
if(dis[j]-dis[fa[f[j][0]][0]]<=i) {f[j][0] = fa[f[j][0]][0];}
}
for(int j=1; j<=LGN; j++)
{
for(int k=1; k<=n; k++) f[k][j] = f[f[k][j-1]][j-1];
}
while(id<q && qr[id+1].p==i)
{
id++;
ans[qr[id].id] = solve_small(qr[id].u,qr[id].v,qr[id].p);
}
}
for(id=id+1; id<=q; id++)
{
ans[qr[id].id] = solve_large(qr[id].u,qr[id].v,qr[id].p);
}
for(int i=1; i<=q; i++) printf("%d\n",ans[i]);
return 0;
}