hihoCoder #1065 全图传送
题意
给出一棵 \(N\)(\(N \le 10^5\))个点的树,有点权和边权。回答 \(q\)(\(q \le 10^5\)) 组询问:
(\(u, r\)):距离节点 \(u\) 不超过 \(r\) 的点中权值最大的点
输出点的编号,如有多解,输出最小编号。
Time Limit: 每个测试点 3s
做法
离线。树的点分治。
以树的重心为根,将无根树转化为有根树。
对于询问 (\(u,r\)),我们把「与 \(u\) 的距离不超过 \(r\) 的点」按「从 \(u\) 到该点是否要经过根节点」分成两类。
问题化为
求从 \(u\) 出发, 经根节点,移动距离不超过 \(r\) 所能到达的点中权值最大的那个点的编号。
-
用
std::map<long long, int> opt
维护 <到根节点的距离,节点编号>(key-value pair)。
opt[d]
表示到当前分治的根节点距离为d
的所有点中最优的那个点。
按 key 从小到大的顺序,对于相邻两 key-value pair 用前一 key 的 value 更新后一 key 的 value 。复杂度 \(O(n\log n)\)(\(n\) 是树的节点数,下同)
-
对于询问 (\(u,r\)),设 \(u\) 到根的距离为 \(d_u\),以 \(r-d_u\) 为参数,用
std::map::upper_bound()
查询,更新该询问的答案。复杂度 \(O(\sum\limits_{u\text{ in the tree}}|\\{(u,r)\\}| \times \log n)\) 。
总复杂度为 \(O((m+n)\log^2 n)\) 。
Implementation
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
vector<pair<int,int>> g[N], q[N];
int w[N];
bool used[N];
int size[N];
int tot;
pair<int,int> centroid(int u, int f){
size[u]=1;
int ma=0;
pair<int,int> res={INT_MAX, 0};
for(auto e: g[u]){
int v=e.first;
if(v!=f && !used[v]){
res=min(res, centroid(v, u));
size[u]+=size[v];
ma=max(ma, size[v]);
}
}
ma=max(ma, tot-size[u]);
res=min(res, {ma, u});
return res;
}
int better(int u, int v){
return w[u]>w[v] || (w[u]==w[v] && u<v) ? u: v;
}
map<long long, int> opt;
void dfs(int u, int f, long long d){
auto it=opt.find(d);
if(it==opt.end())
opt[d]=u;
else it->second=better(u, it->second);
for(auto e: g[u]){
int v=e.first, w=e.second;
if(v!=f && !used[v])
dfs(v, u, d+w);
}
}
int res[N];
void upd(int u, int f, long long d){
for(auto query: q[u]){
int r=query.first, id=query.second;
auto it=opt.upper_bound(r-d);
if(it!=opt.begin()){
res[id]=better(res[id], (--it)->second); //error-prone
}
}
for(auto e: g[u]){
int v=e.first, w=e.second;
if(v!=f && !used[v])
upd(v, u, d+w);
}
}
void DC(int u){
int root=centroid(u, u).second;
opt.clear(); // error-prone
dfs(root, root, 0);
for(auto it=opt.begin(); ;){
int tmp=it->second;
if(++it!=opt.end())
it->second=better(it->second, tmp);
else break;
}
upd(root, root, 0);
used[root]=true;
int _tot = tot;
for(auto e: g[root]){
int v=e.first;
if(!used[v]){
if(size[v] < size[root]) tot = v;
else tot = _tot - size[root];
DC(v);
}
}
int main(){
int n;
scanf("%d", &n);
for(int i=1; i<=n; i++)
scanf("%d", w+i);
for(int i=1; i<n; i++){
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
g[u].push_back({v, w});
g[v].push_back({u, w});
}
int m;
scanf("%d", &m);
for(int i=0; i<m; i++){
int u, r;
scanf("%d%d", &u, &r);
q[u].push_back({r, i});
}
tot = n;
DC(1);
for(int i=0; i<m; i++)
printf("%d\n", res[i]);
return 0;
}