Codeforces 1188D Make Equal DP
题意:给你个序列,你可以给某个数加上2的幂次,问最少多少次可以让所有的数相等。
思路(官方题解):我们先给序列排序,假设bit(c)为c的二进制数中1的个数,假设所有的数最后都成为了x, 显然x >= a[n],那么最后的总花费为Σbit(x - a[i])。不妨假设x = t + a[n], b[i] = a[n] - a[i], 那么问题转化为了求Σbit(t + b[i])的最小值。我们假设最后取得最小值的数是x。我们假设已经知道了填充x的低k - 1位的最小花费,我们来考虑填充第k位。我们发现,知道第k位对答案的贡献需要知道以下3项:1:b序列在第k位的二进制位。2:低位的进位。3:这一位x是填0还是填1。对于第一项,我们已经知道了。第三项在dp的时候直接判断转移就行,现在的问题是第二项。乍一看,我们需要2 ^ n个状态来表示上一位的进位情况。仔细想一想发现,因为这一位要么填0,要么填1,即要么不加数,要么加同一个数。所以,我们对所有的数取模 2 ^ k后从小到大排序,最后产生进位的位置是一个后缀,这样我们就可以把状态数降成O(n)个。我们假设dp[i][j]为填x的第i位,上一位产生进位的后缀长度为j的最小花费,转移的时候判断这位填0还是填1就可以了。dp的复杂度是O(n * log(n) * log(max(a)), 为什么是log(max(a))呢?因为t的大小不会大于b[1] = a[n] - a[1]。官方题解有证明,但是我们可以直观感受一下这是对的:当t正好等于b[1]时, 相当于是b[1] << 1, 之后的操作会重复,并且不会使答案更小。
采用基于二进制数的基数排序可以使复杂度进一步降低为O(n * log(max(a))。
代码(基数排序版本):
#include <bits/stdc++.h> #define LL long long using namespace std; const int maxn = 100010; LL dp[70][maxn]; LL a[maxn], f[maxn], f1[maxn]; int cnt; int sum0[maxn], sum1[maxn], tot0, tot1, n; void Sort(LL x) { cnt = 0; for (int i = 1; i <= n; i++) if(((a[f[i]] >> x) & 1) == 0) f1[++cnt] = f[i]; for (int i = 1; i <= n; i++) if((a[f[i]] >> x) & 1) f1[++cnt] = f[i]; for (int i = 1; i <= n; i++) f[i] = f1[i]; } void init(LL x) { for (int i = 1; i <= n; i++) { sum0[i] = sum0[i - 1]; sum1[i] = sum1[i - 1]; if((a[f[i]] >> x) & 1) sum1[i]++; else sum0[i]++; } tot1 = sum1[n], tot0 = sum0[n]; } int main() { LL mx = 0; scanf("%d", &n); for (int i = 1; i <= n; i++) { scanf("%lld", &a[i]); mx = max(mx, a[i]); } for (int i = 1; i <= n; i++) { a[i] = mx - a[i]; } memset(dp, 0x3f, sizeof(dp)); for (int i = 1; i <= n; i++) f[i] = i; dp[0][0] = 0; for (LL i = 0; i < 62; i++) { if(i != 0) { Sort(i - 1); } init(i); for (int j = 0; j <= n; j++) { for (int bit = 0; bit < 2; bit++) { if(bit == 0) { int val = sum1[n - j] + tot0 - sum0[n - j], Next_state = tot1 - sum1[n - j]; dp[i + 1][Next_state] = min(dp[i + 1][Next_state], dp[i][j] + val); } else if(bit == 1) { int val = sum0[n - j] + tot1 - sum1[n - j], Next_state = n - sum0[n - j]; dp[i + 1][Next_state] = min(dp[i + 1][Next_state], dp[i][j] + val); } } } } printf("%lld\n", dp[62][0]); }
官方题解有一些证明,代码也比较易懂。