【题解】 CF1534G A New Beginning dp+函数键值化
不难发现:总存在一种最优方案使得每一个土豆都在 \(|X-x|=|Y-y|\) 的时候发出。
那么对于每一个土豆都有一个 \(\times\) 形状的触发区间,但实际上我们只要保留形如 \
的,形如 /
总会劣于前者。
所以我们只需要在经过每个土豆的 \(y=-x+b\) 这条直线上进行贡献计算就行了。
现在我们对于土豆按 \(x+y\) 进行分层。可以对于按层进行一个 \(dp\),我们设 \(dp_{i,j}\) 表示走到 \((j,i-j)\) (即第 \(i\) 层 \(x=j\) 的位置)的最小代价。
其实可以把第 \(i\) 层的 \(dp\) 数组看成一个函数,它经过 \((j,dp_{i,j})\),我们做 \(dp\) 转移实际上就是函数值的变动。
我们考虑一下 \(dp_{i}\) 这个函数到 \(dp_{k}\) 时产生的变化,其中 \(i,k\) 是两个相邻的有土豆的层。
这个变化分两步:
- 首先第 \(k\) 层是由第 \(i\) 层的某些位置转移过来的。
注意这里转移要求 \(j-t\) 是一个合法的位置。
- 然后我们再给 \(dp_{k}\) 的每个位置加上这一层发射土豆的代价,不难发现这个代价函数 \(f_k(x)\) 是关于 \(x\) 的一个凸函数。
所以说我们的 \(dp\) 数组在任意时刻也是一个凸函数,所以第一步中我们的 \(dp\) 转移就非常好完成。
所以算法流程是:
- 我们找到函数最低点 \(x_0\)(如果存在多个最低点,即一条斜率为 \(0\) 的线段,则选择横坐标尽可能大的),将当前凸函数从这个最低点切开,右半部分向右平移 \(k-i\) 格。中间缺失的部分补成函数最低点的值(即 \(dp_{k,x_0}\))。
- 将函数加上发射土豆的代价函数 \(f_k(x)\)。
具体实现可以直接维护 \(dp\) 数组的导数(也就是斜率)的差分,记录在哪些 \(x\) 斜率有增加,增加了多少。记录的这些地方就是凸函数的所有“凸起”的位置。
至于找最低点可以二分。
至于加上代价函数,可以直接暴力修改,因为代价函数也是凸的,且它的折线段不会超过 \(O(n)\) 段。
发现可以用个 splay 维护这些信息。
细节:每个土豆最多新增 \(1\) 个折线段,每次函数平移最多增加 \(1\) 个折线段,故最多有 \(2n\) 个折线段,记得开两倍数组。
UPD1:现在我发现,最小值的位置移动次数是 \(O(n)\) 的,可以直接拿对顶堆来维护,就不需要写 splay 了……
可怜的代码
#include <bits/stdc++.h>
#define debug(...) ;//fprintf(stderr ,__VA_ARGS__)
#define __FILE(x)\
freopen(#x".in" ,"r" ,stdin);\
freopen(#x".out" ,"w" ,stdout)
#define LL long long
const int MX = (8e5 + 23) * 2;
const LL MOD = 998244353;
int read(){
char k = getchar(); int x = 0;
while(k < '0' || k > '9') k = getchar();
while(k >= '0' && k <= '9') x = x * 10 + k - '0' ,k = getchar();
return x;
}
struct POINT{
int x ,y;
bool operator <(const POINT& B)const{
return x + y == B.x + B.y ? x < B.x : x + y < B.x + B.y;
}
}A[MX];
namespace SPLAY{
#define lch(x) ch[x][0]
#define rch(x) ch[x][1]
int ch[MX][2] ,fa[MX] ,root ,tot;
int add[MX] ,key[MX] ,size[MX] ,cnt[MX];
int get(int x){return x == rch(fa[x]);}
int Nroot(int x){return get(x) || x == rch(fa[x]);}
void pushup(int x){size[x] = size[lch(x)] + size[rch(x)] + cnt[x];}
void doadd(int x ,int v){
add[x] += v ,key[x] += v;
}
void pushdown(int x){
if(add[x]){
if(lch(x)) doadd(lch(x) ,add[x]);
if(rch(x)) doadd(rch(x) ,add[x]);
add[x] = 0;
}
}
void rotate(int x){
int f = fa[x] ,gf = fa[f] ,which = get(x) ,W = ch[x][!which];
if(Nroot(f)) ch[gf][get(f)] = x;
ch[x][!which] = f ,ch[f][which] = W;
if(W) fa[W] = f;
fa[f] = x ,fa[x] = gf ,pushup(f) ,pushup(x);
}
int stk[MX] ,dep;
void splay(int x ,int goal = 0){
int f = x; stk[++dep] = f;
while(fa[f] != goal) stk[++dep] = f = fa[f];
while(dep) pushdown(stk[dep--]);
while((f = fa[x]) != goal){
if(fa[f] != goal) rotate(get(x) == get(f) ? f : x);
rotate(x);
}if(!goal) root = x;
}
void insert(int val ,int QWQWQ ){
int x = root ,f = 0;
while(x && val != key[x]){
pushdown(x);
f = x ,x = ch[x][val > key[x]];
}
if(x && val == key[x]){
cnt[x] += QWQWQ;
size[x] += QWQWQ;
return splay(x);
}
x = ++tot ,fa[x] = f;
size[x] = cnt[x] = QWQWQ ,key[x] = val;
if(f) ch[f][val > key[f]] = x ,pushup(f);
return splay(x);
}
std::pair<int ,int> lower_bound(int value){
// 第一个数字返回结点标号
// 第二个数字返回斜率
int x = root ,sum = 0;
while(x){
pushdown(x);
if(lch(x) && sum + size[lch(x)] >= value) x = lch(x);
else{
sum += cnt[x] + (lch(x) ? size[lch(x)] : 0);
if(sum >= value) return {x ,sum};
x = rch(x);
}
}
assert(0);
return {-1 ,-1};
}
int prev(int x){
splay(x);
x = lch(x);
while(rch(x)){
pushdown(x);
x = rch(x);
}
return x;
}
int next(int x){
splay(x);
x = rch(x);
while(lch(x)){
pushdown(x);
x = lch(x);
}
return x;
}
void output(int x){
if(!x) return;
pushdown(x);
output(lch(x));
debug("pos: %d delk = %d\n" ,key[x] ,cnt[x]);
output(rch(x));
}
}using namespace SPLAY;
int main(){
int n = read();
const int MXV = 1e9 + 1;
for(int i = 1 ,x ,y ; i <= n ; ++i){
x = read() ,y = read();
A[i] = (POINT){x ,y};
if(x + y == 0){
--i;
--n;
}
}
if(!n) return puts("0") ,0;
insert(INT_MIN ,0);
std::vector<int> addition;
std::sort(A + 1 ,A + 1 + n);
int minpos = 0 ,las = 0;
for(int i = 1 ,j ; i <= n ; i = j){
j = i;
addition.clear();
while(j <= n && A[i].x + A[i].y == A[j].x + A[j].y){
addition.push_back(A[j].x);
++j;
}
if(i != 1){
std::pair<int ,int> tmp = lower_bound(i - 1);
if(tmp.second == i - 1){
tmp = lower_bound(i);
tmp = {prev(tmp.first) ,i - 1};
}
minpos = lower_bound(i - 1).first;
splay(minpos);
int kksk = prev(minpos);
splay(kksk) ,splay(minpos ,kksk);
int del = A[i].x + A[i].y - las;
int location = key[minpos];
int psum = tmp.second - cnt[minpos];
int k = psum - (i - 1);
doadd(minpos ,del);
insert(location ,-k);
insert(location + del ,k);
}
for(auto k : addition){
insert(k ,2);
}
las = A[i].x + A[i].y;
}
LL sum = 0;
for(int i = 1 ; i <= n ; ++i)
sum += A[i].x;
insert(0 ,0);
int cur = next(lower_bound(-1).first);
las = 0;
int k = -n;
while(1){
k += cnt[cur];
if(k >= 0) break;
int nxt = next(cur);
sum += 1LL * k * (key[nxt] - key[cur]);
cur = nxt;
}
printf("%lld\n" ,sum);
return 0;
}