【题解】A19337.火星背包

\(\bf{用 CDQ 分治可以极大地提升程序运行的速度。}\)
\(\bf{实测在本数据量下,可以在 \color{red}10ms\color{normal}} 内通过所有的测试点!\)

关于折半搜索的内容可以参考这两篇题解:

火星背包 - アイドル
火星背包|折半搜索 - Macw

思路分析

这道题是一道超大背包的典型问题,标准做法应该是使用 折半搜索。搜索两次再将两次的答案分别记录下来,最后拼凑出一个正确的答案。但对于超大背包类型问题,在折半搜索的过程需要记录所有的可以达到的点的状态(选和不选的组合)。一个显而易见的问题就是,这会生成许多 又重又不值钱 的状态,然而我们根本就不需要用这些无效状态。

一个可行的做法就是通过分治做法,基于普通折半搜索的基础上加上分治优化,不断地将区间缩小到原来的二分之一,在每一层合并的时候就可以提前先把那些无效状态删除,防止在后续的合并中被使用。

经过测试,在一般数据下,分治可以在几毫秒之内完成。但在极限数据下(即没有任何的无效数据),程序的运行速度相较于 std 会慢一些。

时间复杂度

本算法的渐进时间复杂度与折半搜索的时间复杂度相同,但是常数比较小。

代码解释

这个算法的关键在于利用 cdq 分治的思想,在每一步中合并左右两边的结果,并通过二分查找找到最优解。这样可以在较短的时间内得到问题的解决方案,尤其适用于处理大规模数据。

CDQ 分治版本

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#define int long long
using namespace std;

int n, m, path;
// 用于存放每一个物品,w是物品的重量,v是物品的价值。
struct obj{
    int w;
    int v;
} arr[55];
struct node{
    int weight;  // 状态所需要的容量
    int value;  // 状态的价值
    int id;  // 记录到达该状态的路径。
    /*
        例如有五个物品:
        a b c d e
        选择 a c e 三个物品,可以用二进制表示:10101
        id 变量用于存储 背包组合(“10101”) 的二进制形式。
    */
}; 
// res[i] 表示从起始位置为i的区间的所有可能方案。
vector<node> res[55], combo;

// 按照weight进行排序,weight相同按照value从小到大排序。
bool cmp(node a, node b){
    if (a.weight != b.weight) return a.weight < b.weight;
    return a.value < b.value;
}

// cdq分治应该是可以过的吧
int cdq(int l, int r){
    // 只剩下一个了,直接返回结果,不需要继续递归下去了。
    if (l == r){
        res[l].clear();
        res[l].push_back((node){0, 0, 0});
        if (arr[l].w <= m)
            res[l].push_back((node){arr[l].w, arr[l].v, 1LL << l});
        if (l == 0 && r == n-1) {
            path = 1;
            return res[l][res[l].size()-1].value;
        };
        return 0;
    }
    // 继续cdq分治。
    int mid = (l + r) >> 1;
    cdq(l, mid); cdq(mid+1, r);

    // 计算结果,最终将左右两边结果合并起来
    if (n == r - l + 1){
        int ans = 0;  // 记录最终答案。
        // 右边的答案。
        auto &right = res[mid + 1];
        for (int i=0; i<res[l].size(); i++){
            int L = 0, R = right.size() - 1;
            while(L <= R){
                int amid = (L + R) >> 1;
                if (res[l][i].weight + right[amid].weight <= m) L = amid + 1;
                else R = amid - 1;
            }
            if (res[l][i].value + right[L - 1].value > ans){
                ans = res[l][i].value + right[L - 1].value;
                path = res[l][i].id + right[L-1].id;
            }
            ans = max(ans, res[l][i].value + right[L - 1].value);
        }
        return ans;
    }

    // 合并左右两半部分区间,看一下在限度内的更佳组合。
    // 归并的核心思想,将左右两边结果合并成大的结果。
    for (int i=0; i<res[l].size(); i++){
        for (int j=0; j<res[mid+1].size(); j++){
            if (res[l][i].weight + res[mid+1][j].weight <= m){
                int s1 = res[l][i].weight + res[mid+1][j].weight;
                int s2 = res[l][i].value + res[mid+1][j].value;
                // 合并左右两种方案的路径总和。
                // 这里使用了二进制的思想,0表示不选,1表示选。
                int s3 = res[l][i].id + res[mid+1][j].id;
                combo.push_back((node){s1, s2, s3});
            } else break;
        }
    }
    // 排序,依照cmp中定义的规则排序。   
    sort(combo.begin(), combo.end(), cmp);
    res[l].clear();
    int mi_v = -1;
    for (int i=0; i<combo.size(); i++){
        if (combo[i].value > mi_v){
            mi_v = combo[i].value;
            res[l].push_back(combo[i]);
        }
    }
    combo.clear();
    return 0;
}

signed main(){
    // 加快输入输出,关闭同步流。
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> m;
    for (int i=0; i<n; i++) 
        cin >> arr[i].w >> arr[i].v;

    // 运行分治算法。
    cout << cdq(0, n-1) << " ";
    // 输出答案。
    vector<int> final_result;
    for (int i=0; i<n; i++){
        if (path >> i & 1){
            final_result.push_back(i+1);
        }
    }
    cout << final_result.size() << endl;
    for (auto i : final_result) cout << i << " "; 
    return 0;
}

折半搜索版本

// 不得不说,MInM的代码是真的长。写起来也好麻烦。
#include <iostream>
#include <algorithm>
#include <vector>
#define int long long
using namespace std;

int n, m, mid;
int w[55], v[55];
int maxa, maxb;

// node 用于记录dfs可以组合出来的所有状态。
struct node{
    int weight;  // 状态所需要的容量
    int value;  // 状态的价值
    int id;  // 记录到达该状态的路径。
    /*
        例如有五个物品:
        a b c d e
        选择 a c e 三个物品,可以用二进制表示:10101
        id 变量用于存储 背包组合(“10101”) 的二进制形式。
    */
}; 
// 用于存储两次dfs所有可以到达的节点。
vector<node> ans1, ans2;

// 按照weight进行排序,weight相同按照value从小到大排序。
bool cmp(node a, node b){
    if (a.weight == b.weight)
        return a.value < b.value;
    return a.weight < b.weight;
}

// 第一次深度优先搜索。
void dfs1(int L, int R, int weight, int value, int id){
    if (weight > m) return ;
    if (L > R){
        // 寻找是否有更优的,没有的话就
        ans1.push_back((node){weight, value, id});
        return ;
    }
    // 两种选择,选物品或者不选物品。
    dfs1(L+1, R, weight, value, id);
    // 这里的位运算可以自己动手画一下。
    dfs1(L+1, R, weight + w[L], value + v[L], (1LL << (L-1LL)) + id);
    return ;
}

// 第二次深度优先搜索。
void dfs2(int L, int R, int weight, int value, int id){
    if (weight > m) return ;
    if (L > R){
        ans2.push_back((node){weight, value, id});
        return ;
    }
    // 两种选择,选物品或者不选物品。
    dfs2(L+1, R, weight, value, id);
    dfs2(L+1, R, weight + w[L], value + v[L], (1LL << (L-1LL)) + id);
    return ;
}

signed main(){
    // 加快输入输出,关闭同步流。
    ios::sync_with_stdio(0); 
    cin.tie(0); cout.tie(0);
    cin >> n >> m;
    for (int i=1; i<=n; i++)
        cin >> w[i] >> v[i];

    // 折半搜索
    mid = n >> 1;
    dfs1(1, mid, 0, 0, 0);
    dfs2(mid + 1, n, 0, 0, 0);

    // 按照规则排序。
    sort(ans1.begin(), ans1.end(), cmp);

    // 表示截至目前得到的最大value。
    int value = 0;  

    // 将ans1中的无效元素清除(非最优解)
    vector<node> tmpath;
    tmpath.push_back((node){-1, -1, -1});
    for (int i=0; i<ans1.size(); i++){
        if (tmpath.back().value < ans1[i].value){
            tmpath.push_back(ans1[i]);
        }
    }

    // 二分拼接前后两半部分的答案。
    int maximum = 0, res = 0;
    for (int i=0; i<ans2.size(); i++){
        int weight = m - ans2[i].weight;
        // 寻找最后的,满足 m - weight
        int l = 1, r = tmpath.size() - 1;
        int ans = 0;
        // 寻找最优解,用二分优化。
        while(l <= r){
            int mid = (l + r) >> 1;
            if (tmpath[mid].weight <= weight){
                l = mid + 1;
                ans = mid;
            } else r = mid - 1;
        }
        // 更新答案
        if (ans != 0 && tmpath[ans].value + ans2[i].value > maximum){
            maximum = tmpath[ans].value + ans2[i].value;
            // 将前后两半部分的路径相加,就可以获得最终的路径。
            // 详情见二进制的加减运算。
            res = (ans2[i].id) + (tmpath[ans].id);
        }
    }
    
    // 结果输出
    vector<int> final_result;
    for (int i=0; i<n; i++){
        if (res >> i & 1){
            final_result.push_back(i+1);
        }
    }

    cout << maximum << " " << final_result.size() << endl;
    for (auto i : final_result) cout << i << " "; 
    return 0;
}
posted @ 2024-04-16 15:30  Macw  阅读(10)  评论(0编辑  收藏  举报