【题解】 CF1534G A New Beginning dp+函数键值化

不难发现:总存在一种最优方案使得每一个土豆都在 \(|X-x|=|Y-y|\) 的时候发出。

那么对于每一个土豆都有一个 \(\times\) 形状的触发区间,但实际上我们只要保留形如 \ 的,形如 / 总会劣于前者。

所以我们只需要在经过每个土豆的 \(y=-x+b\) 这条直线上进行贡献计算就行了。

现在我们对于土豆按 \(x+y\) 进行分层。可以对于按层进行一个 \(dp\),我们设 \(dp_{i,j}\) 表示走到 \((j,i-j)\) (即第 \(i\)\(x=j\) 的位置)的最小代价。

其实可以把第 \(i\) 层的 \(dp\) 数组看成一个函数,它经过 \((j,dp_{i,j})\),我们做 \(dp\) 转移实际上就是函数值的变动。

我们考虑一下 \(dp_{i}\) 这个函数到 \(dp_{k}\) 时产生的变化,其中 \(i,k\) 是两个相邻的有土豆的层。

这个变化分两步:

  1. 首先第 \(k\) 层是由第 \(i\) 层的某些位置转移过来的。

\[dp_{k,j}= \min_{t=0}^{k-i} dp_{i,j-t} \]

注意这里转移要求 \(j-t\) 是一个合法的位置。

  1. 然后我们再给 \(dp_{k}\) 的每个位置加上这一层发射土豆的代价,不难发现这个代价函数 \(f_k(x)\) 是关于 \(x\) 的一个凸函数。

所以说我们的 \(dp\) 数组在任意时刻也是一个凸函数,所以第一步中我们的 \(dp\) 转移就非常好完成。

所以算法流程是:

  1. 我们找到函数最低点 \(x_0\)(如果存在多个最低点,即一条斜率为 \(0\) 的线段,则选择横坐标尽可能大的),将当前凸函数从这个最低点切开,右半部分向右平移 \(k-i\) 格。中间缺失的部分补成函数最低点的值(即 \(dp_{k,x_0}\))。
  2. 将函数加上发射土豆的代价函数 \(f_k(x)\)

具体实现可以直接维护 \(dp\) 数组的导数(也就是斜率)的差分,记录在哪些 \(x\) 斜率有增加,增加了多少。记录的这些地方就是凸函数的所有“凸起”的位置。

至于找最低点可以二分。

至于加上代价函数,可以直接暴力修改,因为代价函数也是凸的,且它的折线段不会超过 \(O(n)\) 段。

发现可以用个 splay 维护这些信息。

细节:每个土豆最多新增 \(1\) 个折线段,每次函数平移最多增加 \(1\) 个折线段,故最多有 \(2n\) 个折线段,记得开两倍数组。

UPD1:现在我发现,最小值的位置移动次数是 \(O(n)\) 的,可以直接拿对顶堆来维护,就不需要写 splay 了……

可怜的代码
#include <bits/stdc++.h>

#define debug(...) ;//fprintf(stderr ,__VA_ARGS__)
#define __FILE(x)\
	freopen(#x".in" ,"r" ,stdin);\
	freopen(#x".out" ,"w" ,stdout)
#define LL long long

const int MX = (8e5 + 23) * 2;
const LL MOD = 998244353;

int read(){
	char k = getchar(); int x = 0;
	while(k < '0' || k > '9') k = getchar();
	while(k >= '0' && k <= '9') x = x * 10 + k - '0' ,k = getchar();
	return x;
}

struct POINT{
	int x ,y;
	bool operator <(const POINT& B)const{
		return x + y == B.x + B.y ? x < B.x : x + y < B.x + B.y;
	}
}A[MX];

namespace SPLAY{
#define lch(x) ch[x][0]
#define rch(x) ch[x][1]
	int ch[MX][2] ,fa[MX] ,root ,tot;
	int add[MX] ,key[MX] ,size[MX] ,cnt[MX];
	int get(int x){return x == rch(fa[x]);}
	int Nroot(int x){return get(x) || x == rch(fa[x]);}
	void pushup(int x){size[x] = size[lch(x)] + size[rch(x)] + cnt[x];}
	void doadd(int x ,int v){
		add[x] += v ,key[x] += v;
	}
	void pushdown(int x){
		if(add[x]){
			if(lch(x)) doadd(lch(x) ,add[x]);
			if(rch(x)) doadd(rch(x) ,add[x]);
			add[x] = 0;
		}
	}
	void rotate(int x){
		int f = fa[x] ,gf = fa[f] ,which = get(x) ,W = ch[x][!which];
		if(Nroot(f)) ch[gf][get(f)] = x;
		ch[x][!which] = f ,ch[f][which] = W;
		if(W) fa[W] = f;
		fa[f] = x ,fa[x] = gf ,pushup(f) ,pushup(x);
	}
	int stk[MX] ,dep;
	void splay(int x ,int goal = 0){
		int f = x; stk[++dep] = f;
		while(fa[f] != goal) stk[++dep] = f = fa[f];
		while(dep) pushdown(stk[dep--]);
		while((f = fa[x]) != goal){
			if(fa[f] != goal) rotate(get(x) == get(f) ? f : x);
			rotate(x);
		}if(!goal) root = x;
	}
	void insert(int val ,int QWQWQ ){
		int x = root ,f = 0;
		while(x && val != key[x]){
			pushdown(x);
			f = x ,x = ch[x][val > key[x]];
		}
		if(x && val == key[x]){
			cnt[x] += QWQWQ;
			size[x] += QWQWQ;
			return splay(x);
		}
		x = ++tot ,fa[x] = f;
		size[x] = cnt[x] = QWQWQ ,key[x] = val;
		if(f) ch[f][val > key[f]] = x ,pushup(f);
		return splay(x);
	}
	std::pair<int ,int> lower_bound(int value){
		// 第一个数字返回结点标号
		// 第二个数字返回斜率
		int x = root ,sum = 0;
		while(x){
			pushdown(x);
			if(lch(x) && sum + size[lch(x)] >= value) x = lch(x);
			else{
				sum += cnt[x] + (lch(x) ? size[lch(x)] : 0);
				if(sum >= value) return {x ,sum};
				x = rch(x);
			}
		}
		assert(0);
		return {-1 ,-1};
	}
	int prev(int x){
		splay(x);
		x = lch(x);
		while(rch(x)){
			pushdown(x);
			x = rch(x);
		}
		return x;
	}
	int next(int x){
		splay(x);
		x = rch(x);
		while(lch(x)){
			pushdown(x);
			x = lch(x);
		}
		return x;
	}
	void output(int x){
		if(!x) return;
		pushdown(x);
		output(lch(x));
		debug("pos: %d delk = %d\n" ,key[x] ,cnt[x]);
		output(rch(x));
	}
}using namespace SPLAY;

int main(){
	int n = read();
	const int MXV = 1e9 + 1;
	for(int i = 1 ,x ,y ; i <= n ; ++i){
		x = read() ,y = read();
		A[i] = (POINT){x ,y};
		if(x + y == 0){
			--i;
			--n;
		}
	}

	if(!n) return puts("0") ,0;
	
	insert(INT_MIN ,0);

	std::vector<int> addition;
	std::sort(A + 1 ,A + 1 + n);

	int minpos = 0 ,las = 0;
	for(int i = 1 ,j  ; i <= n ; i = j){
		j = i;
		addition.clear();
		while(j <= n && A[i].x + A[i].y == A[j].x + A[j].y){
			addition.push_back(A[j].x);
			++j;
		}
		if(i != 1){
			std::pair<int ,int> tmp = lower_bound(i - 1);
			if(tmp.second == i - 1){
				tmp = lower_bound(i);
				tmp = {prev(tmp.first) ,i - 1};
			}
			minpos = lower_bound(i - 1).first;
			splay(minpos);
			
			int kksk = prev(minpos);
			splay(kksk) ,splay(minpos ,kksk);
			int del = A[i].x + A[i].y - las;

			int location = key[minpos];
			int psum = tmp.second - cnt[minpos];
			int k = psum - (i - 1);

			doadd(minpos ,del);
			insert(location ,-k);
			insert(location + del ,k);
		}
		for(auto k : addition){
			insert(k ,2);
		}
		las = A[i].x + A[i].y;
	}
	
	LL sum = 0;
	for(int i = 1 ; i <= n ; ++i)
		sum += A[i].x;
	insert(0 ,0);
	int cur = next(lower_bound(-1).first);
	las = 0;
	int k = -n;
	while(1){
		k += cnt[cur];
		if(k >= 0) break;
		int nxt = next(cur);
		sum += 1LL * k * (key[nxt] - key[cur]);
		cur = nxt;
		
	}
	printf("%lld\n" ,sum);
	return 0;
}
posted @ 2021-06-25 17:05  Imakf  阅读(175)  评论(1编辑  收藏  举报