codeforces 739 E. Gosha is hunting
有$n$个神奇宝贝,以及两种神奇宝贝球
用第一种神奇宝贝球去捕捉第$i$个神奇宝贝,捕捉成功的概率为$p_i$
用第二种神奇宝贝球去捕捉第$i$个神奇宝贝,捕捉成功的概率为$q_i$
你有$a$个第一种神奇宝贝球,和$b$个第二种神奇宝贝球
对于每一个神奇宝贝,每一种神奇宝贝球最多只能对其使用一次
安排一种捕捉方式,最大化捕捉的神奇宝贝的个数
$2 \le n \le 2000$
$0 \le a, b \le n$
$0 \le p_i,q_i \le 1$
如果只用第一种球的话,期望会捕捉到$p_i$个第$i$个神奇宝贝
第二种球的话是$q_i$
同时使用的话是$1-(1-p_i)(1-q_i)=p_i+q_i-p_iq_i$
首先可以想到一个十分暴力的$O(n^3)$的动态规划
设$f[i][j][k]$表示前$i$个神奇宝贝,用了$j$个第一种球,和$k$个第二种球的获得神奇宝贝的个数的最大期望
则$f[i][j][k] = \max(f[i - 1][j - 1][k] + p[i], f[i - 1][j][k - 1] + q[i], f[i - 1][j - 1][k - 1] + p[i] + q[i] - p[i]q[i], f[i - 1][j][k])$
1 #include <bits/stdc++.h> 2 using namespace std; 3 const int N = 2010; 4 int n, a, b, t; 5 double p[N], q[N], f[2][N][N]; 6 7 int main() { 8 scanf("%d%d%d", &n, &a, &b); 9 for(int i = 1 ; i <= n ; ++ i) scanf("%lf", &p[i]); 10 for(int i = 1 ; i <= n ; ++ i) scanf("%lf", &q[i]); 11 for(int k = 1 ; k <= n ; ++ k) { 12 t ^= 1; 13 for(int i = 0 ; i <= a ; ++ i) { 14 for(int j = 0 ; j <= b ; ++ j) { 15 16 double x = 0, y = 0, z = 0; 17 if(i) x = f[t ^ 1][i - 1][j] + p[k]; 18 if(j) y = f[t ^ 1][i][j - 1] + q[k]; 19 if(i && j) z = f[t ^ 1][i - 1][j - 1] + p[k] + q[k] - p[k] * q[k]; 20 f[t][i][j] = max(max(x, y), max(z, f[t ^ 1][i][j])); 21 } 22 } 23 } 24 printf("%lf\n", f[t][a][b]); 25 }
考虑费用流建图,建立$S,T,s_a,s_b$节点,以及$x_1 \sim x_n$节点
$S$向$s_a$连流量为$a$,费用为$0$的边
$S$向$s_b$连流量为$b$,费用为$0$的边
$s_a$向$x_i$连流量为$1$,费用为$p_i$的边
$s_b$向$x_i$连流量为$1$,费用为$q_i$的边
$x_i$向$T$连流量为$1$,费用为$0$的边
$x_i$再向$T$连流量为$1$,费用为$-p_iq_i$的边
在这张图上跑最大费用最大流,这样的话,对于每一个神奇宝贝,要么只在只选择一个神奇宝贝球的情况下只获得$p_i$或者$q_i$的贡献,如果同时经过的话,会额外获得$-p_iq_i$的贡献,即$p_i+q_i-p_iq_i$
1 #include <bits/stdc++.h> 2 using namespace std; 3 const int N = 2e5 + 10; 4 const double eps = 1e-8, inf = 1e15; 5 int n, a, b; 6 7 int head[N], rest[N], from[N], to[N], f[N], tot = 1; 8 double c[N]; 9 void add_sig(int u, int v, double c, int f) { 10 from[++ tot] = u, to[tot] = v, :: c[tot] = c, :: f[tot] = f, rest[tot] = head[u], head[u] = tot; 11 } 12 13 void add(int u, int v, double c, int f) { 14 add_sig(u, v, c, f); 15 add_sig(v, u, -c, 0); 16 } 17 18 int Sa, Sb, T, S, pre[N], inq[N]; 19 20 double dis[N], ans, p[N], q[N]; 21 22 int spfa() { 23 queue<int> q; 24 for(int i = 1 ; i <= S ; ++ i) dis[i] = -inf, pre[i] = 0; 25 q.push(S); 26 dis[S] = 0, inq[S] = 1; 27 while(q.size()) { 28 int u = q.front(); q.pop(); inq[u] = 0; 29 for(int i = head[u] ; i ; i = rest[i]) { 30 int v = to[i]; 31 if(f[i] && dis[u] + c[i] > dis[v] + 1e-8) { 32 dis[v] = dis[u] + c[i]; 33 pre[v] = i; 34 if(!inq[v]) inq[v] = 1, q.push(v); 35 } 36 } 37 } 38 return dis[T] > -inf; 39 } 40 41 int main() { 42 scanf("%d%d%d", &n, &a, &b); 43 for(int i = 1 ; i <= n ; ++ i) scanf("%lf", &p[i]); 44 for(int i = 1 ; i <= n ; ++ i) scanf("%lf", &q[i]); 45 Sa = n + 1, Sb = n + 2, T = n + 3, S = n + 4; 46 add(S, Sa, 0, a); 47 add(S, Sb, 0, b); 48 for(int i = 1 ; i <= n ; ++ i) { 49 add(Sa, i, p[i], 1); 50 add(Sb, i, q[i], 1); 51 add(i, T, 0, 1); 52 add(i, T, -p[i] * q[i], 1); 53 } 54 while(spfa()) { 55 int mn = 0x3f3f3f3f; 56 for(int i = pre[T] ; i ; i = pre[from[i]]) mn = min(mn, f[i]); 57 for(int i = pre[T] ; i ; i = pre[from[i]]) f[i] -= mn, f[i ^ 1] += mn; 58 ans += mn * dis[T]; 59 } 60 printf("%lf\n", ans); 61 }
再回头看一下暴力$dp$的转移……发现是一个有限制个数的转移……
不妨想到$wqs$二分……然后就优化到了$O(n^2 \log n)$……
当然可以对于两类的精灵球都进行二分……这样的话就是$O(n \log^2 n)$……
1 #include <bits/stdc++.h> 2 using namespace std; 3 const int N = 2010; 4 int n, a, b, t, g[N][N]; 5 double p[N], q[N], f[N][N]; 6 7 void update(double &f, int &g, double nf, int ng) { if(nf > f) f = nf, g = ng; } 8 #define _f f[i][j] 9 #define _g g[i][j] 10 int sol(double k) { 11 for(int i = 1 ; i <= n ; ++ i) 12 for(int j = 0 ; j <= a ; ++ j) { 13 _f = _g = 0; 14 update(_f, _g, f[i - 1][j], g[i - 1][j]); 15 update(_f, _g, f[i - 1][j] + q[i] - k, g[i - 1][j] + 1); 16 if(j) update(_f, _g, f[i - 1][j - 1] + p[i], g[i - 1][j - 1]); 17 if(j) update(_f, _g, f[i - 1][j - 1] + p[i] + q[i] - p[i] * q[i] - k, g[i - 1][j - 1] + 1); 18 } 19 return g[n][a]; 20 } 21 22 int main() { 23 scanf("%d%d%d", &n, &a, &b); 24 for(int i = 1 ; i <= n ; ++ i) scanf("%lf", &p[i]); 25 for(int i = 1 ; i <= n ; ++ i) scanf("%lf", &q[i]); 26 double l = -1e4, r = 1e4; 27 for(int i = 1 ; i <= 50 ; ++ i) { 28 double mid = (l + r) / 2; 29 if(sol(mid) < b) r = mid; 30 else l = mid; 31 } 32 printf("%lf\n", f[n][a] + l * b); 33 }