wqs二分

wqs二分

用来处理一类带有限制的问题,如恰好选 \(k\) 个,本质是通过二分来规避这个选取数量的限制。

使用前提:原问题具有凹凸性。设 \(g_i\) 表示选 \(i\) 个物品的答案,那么所有 \((i, g_i)\) 点组成一个凸包,满足 \(g'(k)\) 单调。

这类题目通常有以下特点:

  • 如果不限制选的个数,那么很容易求出最优方案。
  • 权值随选的物品增加而单调。

实现

先不考虑恰好选 \(k\) 个的限制,则 DP 的复杂度会降低。

考虑二分一个 \(mid\) ,表示一次选取附加的权值。则选取的次数越多,附加权值越大,选的就会越多/越少,根据选的数量来调整 \(mid\) ,最后调整到恰好选 \(k\) 个时减掉附加权值即为答案。

本质就是二分斜率使得凸包切线切点在 \(x = k\) 处,检查函数返回的是截距。

细节

如果斜率为整数时,答案的切线会同时切多个点,此时就会出现 \(mid\) 时切到的 \(x\) 坐标小于 \(k\) ,而 \(mid + 1\) 时切到的 \(x\) 坐标大于 \(k\)

注意到一个点对应的切线斜率是一段区间,而且我们只要保证这个斜率会切到 \(k\) 即可。由于 \(g(k + 1) - g(k)\) 为整数,所以这个斜率一定存在。

实现时每次判定时能选就选,这样就会切到最右边的点,此时为了让斜线的斜率逼近 \(g'(k)\) ,于是就要根据凸壳形状减小(下凸壳)或增大(上凸壳)切线斜率,并更新答案。

构造解

对于 wqs 二分能处理的大部分问题,其 DP 过程中的答案仍然是凸的。

如对于序列分段问题,DP 的就是每一个前缀的切点,每个切点的坐标是由前面的切点转移过来的。

所以对于每一个 DP 的子问题,我们都记录切线所切的区间。这样我们可以从最后开始,每次尝试找到上一个可行的转移点满足 \(k\) 要落在其切点的区间内。

不妨用 \([l_i, r_i]\) 记录第 \(i\) 个凸包的切线区间,那么我们只需要找到 \(k \in [l_j, r_j]\) 并且可以转移当前位置的 \(j\) 即可。

应用

P5633 最小度限制生成树 MST Company

\(deg_s = k\) 的情况下的最小生成树。

\(n \leq 5 \times 10^4, m \leq 5 \times 10^5, k \leq 100\)

二分一个 \(mid\) ,与 \(s\) 连接的边的边权都加上这个 \(mid\) ,则最终MST中与 \(s\) 连接的边的数量受 \(mid\) 影响,于是可以调整使得恰好 \(deg_s =k\)

时间复杂度 \(O(m \log m \log V)\)

P2619 [国家集训队] Tree I 做法也是类似的。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 5e4 + 7, M = 5e5 + 7;

struct Edge {
	int u, v, w, c;
	
	inline bool operator < (const Edge &rhs) const {
		return w == rhs.w ? c > rhs.c : w < rhs.w;
	}
} e[M];

struct DSU {
	int fa[N];

	inline void prework(int n) {
		iota(fa +1 , fa + 1 + n, 1);
	}

	inline int find(int x) {
	    while (x != fa[x])
	        fa[x] = fa[fa[x]], x = fa[x];

	    return x;
	}

	inline void merge(int x, int y) {
	    fa[find(y)] = find(x);
	}
} dsu;

int n, m, s, k, tot;

template <class T = int>
inline T read() {
	char c = getchar();
	bool sign = (c == '-');
	
	while (c < '0' || c > '9')
		c = getchar(), sign |= (c == '-');
	
	T x = 0;
	
	while ('0' <= c && c <= '9')
		x = (x << 1) + (x << 3) + (c & 15), c = getchar();
	
	return sign ? (~x + 1) : x;
}

inline pair<ll, int> Kruskal() {
	sort(e + 1, e + 1 + m);
	dsu.prework(n);
	pair<ll, int> ans = make_pair(0ll, 0);
	
	for (int i = 1; i <= m; ++i) {
		if (dsu.find(e[i].u) == dsu.find(e[i].v))
			continue;
		
		dsu.merge(e[i].u, e[i].v);
		ans.first += e[i].w, ans.second += e[i].c;
	}
	
	return ans;
}

inline pair<ll, int> check(int lambda) {
	for (int i = 1; i <= m; ++i)
		if (e[i].c)
			e[i].w += lambda;
	
	pair<ll, int> ans = Kruskal();

	for (int i = 1; i <= m; ++i)
		if (e[i].c)
			e[i].w -= lambda;
	
	return ans;
}

signed main() {
	n = read(), m = read(), s = read(), k = read();
	dsu.prework(n);
	
	for (int i = 1; i <= m; ++i) {
		e[i].u = read(), e[i].v = read(), e[i].w = read(), e[i].c = (e[i].u == s || e[i].v == s);
		dsu.merge(e[i].u, e[i].v);
	}

	for (int i = 2; i <= n; ++i)
		if (dsu.find(i) != dsu.find(1))
			return puts("Impossible"), 0;
	
	if (check(-inf).second < k || check(inf).second > k)
		return puts("Impossible"), 0;
		
	int l = -inf, r = inf, ans = 0;
	
	while (l <= r) {
		int mid = (l + r) >> 1;
		
		if (check(mid).second >= k)
			l = mid + 1, ans = mid;
		else
			r = mid - 1;
	}
	
	printf("%lld\n", check(ans).first - 1ll * k * ans);
	return 0;
}

P4983 忘情

定义一段序列 \(a_{l \sim r}\) 的价值为:

\[((\sum_{i = l}^r a_i) + 1)^2 \]

\(a_{1 \sim n}\) 分为 \(m\) 段,求每段价值和的最小值。

\(n \leq 10^5\)

首先由于 \((a + b)^2 \geq a^2 + b^2\) ,所以要分的段数要越多越好,答案关于 \(m\) 增加而减小,且因为先选减小值更小的地方分开更优,具有凸性。

于是可以 wqs 二分,判定部分可以斜率优化,时间复杂度 \(O(n \log V)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

ll a[N], s[N], f[N];
int q[N], g[N];

int n, m;

template <class T = int>
inline T read() {
	char c = getchar();
	bool sign = (c == '-');
	
	while (c < '0' || c > '9')
		c = getchar(), sign |= (c == '-');
	
	T x = 0;
	
	while ('0' <= c && c <= '9')
		x = (x << 1) + (x << 3) + (c & 15), c = getchar();
	
	return sign ? (~x + 1) : x;
}

inline double slope(int i, int j) {
	return ((f[i] + s[i] * s[i]) - (f[j] + s[j] * s[j])) / (s[i] - s[j]);
}

inline pair<ll, int> check(ll mid) {
	for (int i = 1, head = 0, tail = 0; i <= n; ++i) {
		while (head < tail && slope(q[head], q[head + 1]) < 2.0 * (s[i] + 1))
			++head;
		
		f[i] = f[q[head]] + (s[i] - s[q[head]] + 1) * (s[i] - s[q[head]] + 1) + mid;
		g[i] = g[q[head]] + 1;
		
		while (head < tail && slope(q[tail - 1], q[tail]) > slope(q[tail], i))
			--tail;
		
		q[++tail] = i;
	}
	
	return make_pair(f[n], g[n]);
}

signed main() {
	n = read(), m = read();
	
	for (int i = 1; i <= n; ++i)
		s[i] = s[i - 1] + (a[i] = read());
	
	ll l = 0, r = 1e18, ans = 0;
	
	while (l <= r) {
		ll mid = (l + r) >> 1;
		
		if (check(mid).second <= m)
			r = mid - 1, ans = mid;
		else
			l = mid + 1;
	}
	
	printf("%lld", check(ans).first - ans * m);
	return 0;
}

CF802O April Fools' Problem (hard)

\(n\) 道题,第 \(i\) 天可以花费 \(a_i\) 准备一道题或花费 \(b_i\) 打印一道题或什么也不做,准备的题可以留到以后打印,求最少花费使得准备并打印 \(k\) 道题。

\(n \leq 5 \times 10^5\)

wqs二分配合反悔贪心即可,时间复杂度 \(O(n \log n \log V)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e5 + 7;

ll a[N], b[N];

int n, k;

template <class T = int>
inline T read() {
	char c = getchar();
	bool sign = (c == '-');
	
	while (c < '0' || c > '9')
		c = getchar(), sign |= (c == '-');
	
	T x = 0;
	
	while ('0' <= c && c <= '9')
		x = (x << 1) + (x << 3) + (c & 15), c = getchar();
	
	return sign ? (~x + 1) : x;
}

inline pair<ll, int> check(ll lambda) {
	ll ans = 0;
	int cnt = 0;
	priority_queue<pair<ll, int> > q;
	
	for (int i = 1; i <= n; ++i)
		b[i] -= lambda;
	
	for (int i = 1; i <= n; ++i) {
		q.emplace(make_pair(-a[i], 1));
		
		if (-b[i] > -q.top().first) {
			ans += b[i] - q.top().first, cnt += q.top().second;
			q.pop(), q.emplace(make_pair(b[i], 0));
		}
	}
	
	for (int i = 1; i <= n; ++i)
		b[i] += lambda;
	
	return make_pair(ans, cnt);
}

signed main() {
	n = read(), k = read();
	
	for (int i = 1; i <= n; ++i)
		a[i] = read();
	
	for (int i = 1; i <= n; ++i)
		b[i] = read();
	
	ll l = -1e18, r = 1e18, ans = 1e18;
	
	while (l <= r) {
		ll mid = (l + r) >> 1;
		
		if (check(mid).second >= k)
			ans = mid, r = mid - 1;
		else
			l = mid + 1;
	}
	
	printf("%lld", check(ans).first + ans * k);
	return 0;
}

CF739E Gosha is hunting

现在一共有 \(n\) 只神奇宝贝,你有 \(a\) 个宝贝球和 \(b\) 个超级球,宝贝球抓到第 \(i\) 只神奇宝贝的概率是 \(p_i\),超级球抓到的概率则是 \(u_i\)

不能往同一只神奇宝贝上使用超过一个同种的球,但是可以往同一只上既使用宝贝球又使用超级球(都抓到算一个)。

求合理分配下抓到神奇宝贝的总个数期望的最大值。

\(n \leq 2000\)

首先球全部用完一定最优。考虑 wqs 二分斜率 \(\lambda\) ,每用一个超级球答案就减去 \(\lambda\) ,DP 出超级球选 \(b\) 个的方案。

直接 DP 是 \(O(n^2 \log V)\) 的,但是可以进一步优化,考虑每个位置的四种选择:

  • 用宝贝球和超级球:\(p + u - pu - \lambda\)
  • 用宝贝球:\(p\)
  • 用超级球: \(u - \lambda\)
  • 不用球:\(0\)

考虑一个位置从不用宝贝球到用宝贝球,答案增加量就是:

\[\max(p + u - pu - \lambda, p) - \max(u - \lambda, 0) \]

对这个增加量排序后贪心即可,时间复杂度 \(O(n \log n \log V)\)

#include <bits/stdc++.h>
using namespace std;
const double eps = 1e-9;
const int N = 2e3 + 7;

struct Node {
	double x;
	int kb, ka;

	inline bool operator < (const Node &rhs) const {
		return x > rhs.x;
	}
} nd[N];

double p[N], u[N];

int n, a, b;

inline pair<double, int> check(double lambda) {
	double ans = 0;
	int cnt = 0;

	for (int i = 1; i <= n; ++i) {
		if (p[i] + u[i] - p[i] * u[i] - lambda > p[i])
			nd[i].ka = 1, nd[i].x = p[i] + u[i] - p[i] * u[i] - lambda;
		else
			nd[i].ka = 0, nd[i].x = p[i];

		if (u[i] - lambda > 0)
			nd[i].kb = 1, nd[i].x -= u[i] - lambda, ans += u[i] - lambda;
		else
			nd[i].kb = 0;
	}

	sort(nd + 1, nd + 1 + n);

	for (int i = 1; i <= a; ++i)
		ans += nd[i].x, cnt += nd[i].ka;

	for (int i = a + 1; i <= n; ++i)
		cnt += nd[i].kb;

	return make_pair(ans, cnt);
}

signed main() {
	scanf("%d%d%d", &n, &a, &b);

	for (int i = 1; i <= n; ++i)
		scanf("%lf", p + i);

	for (int i = 1; i <= n; ++i)
		scanf("%lf", u + i);

	double l = 0, r = 1, ans = 1;

	while (r - l > eps) {
		double mid = (l + r) / 2;

		if (check(mid).second >= b)
			l = mid;
		else
			ans = r = mid;
	}

	printf("%.6lf", check(ans).first + ans * b);
	return 0;
}

P4383 【八省联考2018】林克卡特树

给定一棵 \(n\) 个点的树,在树上选出 \(k + 1\) 条互不相交链,最大化其权值之和。

\(n \leq 3 \times 10^5\)

根据 wqs 二分的经典套路,下面考虑没有 \(k\) 的限制怎么做。

\(f_{u, 0/1/2}\) 表示考虑到 \(u\)\(deg_u = 0/1/2\) 时的答案,更进一步的:

  • \(deg = 0\) :这个点没有连边。
  • \(deg = 1\) :这个点连着一条未完成的链,该链还未计入答案。
  • \(deg = 2\) :这个点连着一条连接两个不同子树的链。

首先约定在每个节点的全部转移结束时,进行一次更新:

\[f_{u, 0} = \max(f_{u, 0}, f_{u, 1}, f_{u, 2}) \]

这样就把 \(u\) 的全部最优解统计了出来。对于 \(deg = 0/2\) 的情况可以直接合并,而对于 \(deg = 1\) 的情况要先把这条链结束掉并统计答案。

则有转移方程:

\[f_{u, 2} = \max \begin{cases} f_{u, 2} + f_{v, 0} \\ f_{u, 1} + w(u, v) + f_{v, 1} \\ f_{u, 2} \end{cases} \]

第一行表示 \(u\) 不接到 \(v\) 上,直接继承 \(v\) 的最优解。

第二行表示 \(u\) 接到 \(v\) 上,继承 \(v\) 一条未完成的链,得到一条完成的链。

\[f_{u, 1} = \max \begin{cases} f_{u, 1} + f_{v, 0} \\ f_{u, 0} + w(u, v) + f_{v, 1} \\ f_{u, 1} \end{cases} \]

第一行表示 \(u\) 不接到 \(v\) 上,直接继承 \(v\) 的最优解。

第二行表示 \(u\) 接到 \(v\) 上,继承 \(v\) 一条未完成的链,得到一条未完成的链

\[f_{u, 0} = \max \begin{cases} f_{u, 0} + f_{v, 0} \\ f_{u, 0} \end{cases} \]

这里 \(u\) 必须不接 \(v\) ,只能取 \(v\) 的最优解。

时间复杂度 \(O(n \log V)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 3e5 + 7;

struct Graph {
	vector<pair<int, int> > e[N];
	
	inline void insert(int u, int v, int w) {
		e[u].emplace_back(v, w);
	}
} G;

pair<ll, int> f[N][3];

int n, k;

template <class T = int>
inline T read() {
	char c = getchar();
	bool sign = (c == '-');
	
	while (c < '0' || c > '9')
		c = getchar(), sign |= (c == '-');
	
	T x = 0;
	
	while ('0' <= c && c <= '9')
		x = (x << 1) + (x << 3) + (c & 15), c = getchar();
	
	return sign ? (~x + 1) : x;
}

inline pair<ll, int> operator + (const pair<ll, int> &a, const pair<ll, int> &b) {
	return make_pair(a.first + b.first, a.second + b.second);
}

inline pair<ll, int> operator + (const pair<ll, int> &a, const ll &b) {
	return make_pair(a.first + b, a.second);
}

void dfs(int u, int fa, const pair<ll, int> lambda) {
	f[u][0] = f[u][1] = f[u][2] = make_pair(0, 0);
	f[u][2] = max(f[u][2], lambda);

	for (auto it : G.e[u]) {
		int v = it.first, w = it.second;

		if (v == fa)
			continue;

		dfs(v, u, lambda);
		f[u][2] = max(f[u][2], max(f[u][2] + f[v][0], f[u][1] + w + f[v][1] + lambda));
		f[u][1] = max(f[u][1], max(f[u][1] + f[v][0], f[u][0] + w + f[v][1]));
		f[u][0] = max(f[u][0], f[u][0] + f[v][0]);
	}

	f[u][0] = max(f[u][0], max(f[u][1] + lambda, f[u][2]));
}

inline pair<ll, int> check(ll lambda) {
	dfs(1, 0, make_pair(lambda, 1));
	return f[1][0];
}

signed main() {
	n = read(), k = read() + 1;

	for (int i = 1; i < n; ++i) {
		int u = read(), v = read(), w = read();
		G.insert(u, v, w), G.insert(v, u, w);
	}

	ll l = -1e12, r = 1e12, ans = 0;

	while (l <= r) {
		ll mid = (l + r) >> 1;

		if (check(mid).second >= k)
			ans = mid, r = mid - 1;
		else
			l = mid + 1;
	}

	printf("%lld", check(ans).first - ans * k);
	return 0;
}

P6246 [IOI2000] 邮局 加强版 加强版

\(n\) 个村庄,放 \(m\) 个邮局,求每个村庄到最近邮局的距离和的最小值。

\(m \leq n \leq 5 \times 10^5\)

\(f_{i, j}\) 表示前 \(i\) 个村庄放 \(j\) 个邮局的最小距离和,\(w(l, r)\) 表示在 \([l, r]\) 范围村庄放一个邮局的最小距离和,则有:

\[f_{i, j} = \min_{k = 0}^{i - 1} \{ f_{k, j - 1} + w(k + 1, i) \} \]

决策单调性优化做到 \(O(n^2)\)

考虑用 wqs 二分规避 \(j\) 的限制,于是得到一个1D/1D 的 DP,并且也有决策单调性,可以二分栈做到 \(O(n \log n \log V)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e5 + 7;

struct Node {
	int j, l, r;
} q[N];

ll s[N], f[N];
int a[N], g[N];

int n, m;

template <class T = int>
inline T read() {
	char c = getchar();
	bool sign = (c == '-');
	
	while (c < '0' || c > '9')
		c = getchar(), sign |= (c == '-');
	
	T x = 0;
	
	while ('0' <= c && c <= '9')
		x = (x << 1) + (x << 3) + (c & 15), c = getchar();
	
	return sign ? (~x + 1) : x;
}

inline ll w(int l, int r) {
	int mid = (l + r) >> 1;
	return (s[r] - s[mid]) - 1ll * a[mid] * (r - mid) + 1ll * a[mid] * (mid - l + 1) - (s[mid] - s[l - 1]);
}

inline ll calc(int i, int j) {
	return f[j] + w(j + 1, i);
}

inline int BinarySearch(int l, int r, int i, int j) {
	int ans = r + 1;

	while (l <= r) {
		int mid = (l + r) >> 1;

		if (calc(mid, i) <= calc(mid, j))
			ans = mid, r = mid - 1;
		else
			l = mid + 1;
	}

	return ans;
}

inline pair<ll, int> check(ll lambda) {
	int head = 1, tail = 0;
	q[++tail] = (Node) {0, 1, n};

	for (int i = 1; i <= n; ++i) {
		if (q[head].r == i - 1)
			++head;

		f[i] = calc(i, q[head].j) + lambda, g[i] = g[q[head].j] + 1;
		int pos = n + 1;

		while (head <= tail) {
			if (calc(q[tail].l, i) <= calc(q[tail].l, q[tail].j))
				pos = q[tail--].l;
			else {
				pos = BinarySearch(q[tail].l, q[tail].r, i, q[tail].j);
				q[tail].r = pos - 1;
				break;
			}
		}

		if (pos != n + 1)
			q[++tail] = (Node) {i, pos, n};
	}

	return make_pair(f[n], g[n]);
}

signed main() {
	n = read(), m = read();

	for (int i = 1; i <= n; ++i)
		a[i] = read();

	sort(a + 1, a + 1 + n);

	for (int i = 1; i <= n; ++i)
		s[i] = s[i - 1] + a[i];

	ll l = 0, r = 1e12, ans = 0;

	while (l <= r) {
		ll mid = (l + r) >> 1;

		if (check(mid).second >= m)
			ans = mid, l = mid + 1;
		else
			r = mid - 1;
	}

	printf("%lld", check(ans).first - ans * m);
	return 0;
}
posted @ 2024-08-10 11:15  我是浣辰啦  阅读(15)  评论(0编辑  收藏  举报