【题解】 「NOI2014」购票 dp+斜率优化+点分治 LOJ2249
Legend
Link \(\textrm{to LOJ}\)。
给定一棵内向树,每个结点(除了根)有如下五个信息:
- 父亲结点 \(f_i\);
- 到父亲的距离 \(s_i\);
- 起步价格 \(q_i\),表示在该节点乘坐交通工具的起步价格;
- 单位路程价格 \(p_i\),表示在该节点乘坐交通工具的单位路程价格;
- 最大路程限制 \(l_i\),超过这个距离的祖先结点不可以用该点的交通工具一次性到达;
其中每次乘车前都把必须指定终点并不能中途下车。
求所有结点到根的最短路。
数据范围懒得写了。
时空:\(\rm{3s/513MiB}\)。
Editorial
感觉是一道斜率优化强行上树的题目。
事实就是这样,我们可以很快推出来序列上的式子。
\(dp_i= \min\limits _{j=1}^{i-1} dp_j + (S_i - S_j)p_i + q_i\),其中 \(S_i\) 表示结点 \(i\) 到根的距离。
\(dp_j = p_i S_j + dp_i - S_ip_i - q_i\),转移点位置 \((S_j,dp_j)\),斜率 \(p_i\),最小化截距。
哼哼,树上斜率优化?不就是回溯的时候重置一下被修改的位置吗?
哼哼,\(p_i\) 没有单调性?就二分一下。
你激动地打完代码交上去发现只有 \(50\) 分,仔细一看,发现方程忘记了 \((S_i - S_j \le l_i)\) 的限制。
然后你就发现这个东西非常不好维护 >_<,穷途末路了吗?
不……序列上这样子的问题有个解决方法是 \(\rm{CDQ}\) 分治。
把序列分成前后两部分,先计算左侧的 \(dp\) 数组,再统计跨越左右的,再递归右边。
统计跨越左右的时候,要按照 \(l_i\) 从右到左排序,就可以照样维护凸包了(只不过插入顺序反过来了而已)。
来到树上的话,就直接换成点分治就好了。
实现上的细节是当前的分治中心 \(x\) 并不会被计算到上半部分的子树(对于序列就是左侧),所以要进行暴力更新。
总复杂度就是 \(O(n \log^2 n)\)。
Code
注意到这题如果用叉积判凸包就会爆 \(\textrm{long long}\)。
于是我直接莽了一发用斜率,没想到过了。
#include <bits/stdc++.h>
#define debug(...) ;//fprintf(stderr ,__VA_ARGS__)
#define LL long long
#define __FILE(x)\
freopen(#x".in" ,"r" ,stdin);\
freopen(#x".out" ,"w" ,stdout)
using namespace std;
const int MX = 2e5 + 233;
LL read(){
char k = getchar(); LL x = 0;
while(k < '0' || k > '9') k = getchar();
while(k >= '0' && k <= '9') x = x * 10 + k - '0' ,k = getchar();
return x;
}
int head[MX] ,tot;
struct edge{
int node ,next;
LL w;
}h[MX << 1];
void addedge(int u ,int v ,LL w ,int flg = 1){
h[++tot] = (edge){v ,head[u] ,w} ,head[u] = tot;
if(flg) addedge(v ,u ,w ,false);
}
int n ,t;
int fa[MX];
LL S[MX] ,p[MX] ,q[MX] ,lim[MX];
void getS(int x){
// debug("visiting %d\n" ,x);
for(int i = head[x] ,d ; i ; i = h[i].next){
if((d = h[i].node) == fa[x]) continue;
S[d] = S[x] + h[i].w;
getS(d);
}
}
int R ,sz[MX] ,mxsz[MX] ,subsize ,vis[MX];
void getGra(int x ,int f){
sz[x] = 1 ,mxsz[x] = 0;
for(int i = head[x] ,d ; i ; i = h[i].next){
if(vis[d = h[i].node] || d == f) continue;
getGra(d ,x);
sz[x] += sz[d];
mxsz[x] = max(mxsz[x] ,sz[d]);
}
mxsz[x] = max(mxsz[x] ,subsize - sz[x]);
if(mxsz[x] < mxsz[R]) R = x;
}
LL dp[MX];
void doit(int x ,int top);
void solve(int x){
debug("SOLVE %d\n" ,x);
vis[x] = 1;
int upper = fa[x];
while(!vis[upper]) upper = fa[upper];
for(int i = head[x] ,d ; i ; i = h[i].next){
if((d = h[i].node) == fa[x]){
if(!vis[d]){
mxsz[R = 0] = subsize = sz[d];
getGra(d ,x);
solve(R);
}
break;
}
}
for(int now = fa[x] ; now != upper && S[x] - S[now] <= lim[x] ; now = fa[now]){
dp[x] = min(dp[x] ,dp[now] + (S[x] - S[now]) * p[x] + q[x]);
}
// debug("solve %d\n" ,x);
doit(x ,upper);
for(int i = head[x] ,d ; i ; i = h[i].next){
if(vis[d = h[i].node] || d == fa[x]) continue;
mxsz[R = 0] = subsize = sz[d];
getGra(d ,x);
solve(R);
}
}
int down[MX] ,dcnt;
bool cmp(int a ,int b){return S[a] - lim[a] > S[b] - lim[b];}
void getDown(int x ,int f ,LL dist){
if(dist <= lim[x]) down[++dcnt] = x;
for(int i = head[x] ,d ; i ; i = h[i].next){
if(vis[d = h[i].node] || d == f) continue;
getDown(d ,x ,dist + h[i].w);
}
}
int que[MX] ,TAIL;
double slope(int j1 ,int j2){
return 1.0 * (dp[j1] - dp[j2]) / (S[j1] - S[j2]);
}
int search(int l ,int r ,int x){
++l;
int mid;
while(l <= r){
mid = (l + r) >> 1;
int j1 = que[mid - 1] ,j2 = que[mid];
if(slope(j1 ,j2) > p[x]){
l = mid + 1;
}
else{
r = mid - 1;
}
}
return l - 1;
}
void doit(int x ,int top){
dcnt = 0;
for(int i = head[x] ,d ; i ; i = h[i].next){
if((d = h[i].node) == fa[x]) continue;
getDown(d ,x ,h[i].w);
}
std::sort(down + 1 ,down + 1 + dcnt ,cmp);
// 越深的排序与越靠前
TAIL = 0;
que[++TAIL] = x;
int trcnt = fa[x];
for(int dd = 1 ; dd <= dcnt ; ++dd){
int now = down[dd];
while(trcnt != top && S[now] - S[trcnt] <= lim[now]){
while(1 < TAIL && slope(trcnt ,que[TAIL]) >= slope(que[TAIL] ,que[TAIL - 1])){
--TAIL;
}
que[++TAIL] = trcnt;
trcnt = fa[trcnt];
}
int tr = que[search(1 ,TAIL ,now)];
// printf("%d tr from %d\n" ,now ,tr);
dp[now] = min(dp[now] ,dp[tr] + (S[now] - S[tr]) * p[now] + q[now]);
}
}
int main(){
vis[0] = 1;
memset(dp ,0x3f ,sizeof dp);
n = read() ,t = read();
for(int i = 2 ; i <= n ; ++i){
fa[i] = read();
LL w = read();
p[i] = read();
q[i] = read();
lim[i] = read();
addedge(i ,fa[i] ,w);
}
getS(1);
getGra(1 ,0);
dp[1] = 0;
solve(1);
for(int i = 2 ; i <= n ; ++i)
printf("%lld\n" ,dp[i]);
return 0;
}