水塘抽样问题 与 洗牌算法

水塘抽样 与 洗牌算法

本文介绍两个相似的问题,水塘抽样和洗牌算法。

水塘抽样(Reservoir Sampling)

水塘抽样(Reservoir Sampling)说的是这样一个问题:当内存无法完全加载时,如何从数据流或大数据集中随机选取k个样本,并保证每个样本被选取的概率相等。

典型问题出现在高纳德《计算机编程艺术》和谷歌面试题中都出现过:可否在一未知大小的集合中,随机取出一元素?/ 在不知道文件总行数的情况下,如何从文件中随机的抽取一行?

这个问题我们可以分两种情况进行讨论:

  • k == 1
  • k > 1

当 k == 1 的时候,我们可以在每次遇到合法对象时,以 1/n 的概率决定是否替换结果对象,其中 n 是当前遇到过的合法对象数目。显然,第一次遇到的时候肯定会替换成结果对象;第二次遇到的时候有一半可能替换,也就是前两个合法对象都有一半可能返回结果;第三次遇到的时候有 1/3 可能替换,前两个合法对象返回的可能都出自剩下这 2/3 可能,从而前三个合法对象返回的概率也一样 …… 归纳法可证所有合法对象返回的概率都相同,且概率总和为 1。

当 k > 1 的时候,我们在前 k 次遇到合法对象的时候直接存入结果数组;之后每一次遇到合法对象,都以 k/n 的可能来替换结果对象,结果数组中每个对象都等概率分到 1/k 份替换概率。参照上面的证明,归纳法可证所有合法对象返回的概率都相同,且概率总和为 k。

代码模板

// k == 1
int cnt = 0;
for (int i=0; i<arr.size(); i++) {
    if (arr[i] != target) continue;
    if ((rand() % ++cnt) == 0) res = i;
}

// k > 1
int cnt = 0;
for (int i=0; i<arr.size(); i++) {
    if (arr[i] != target) continue;
    if (cnt < k) {
        res[cnt++] = i;
    } else {
        int j = (rand() % ++cnt);
        if (j < k) res[j] = i;
    }
}

下面给出几道例题:

LeetCode 398. Random Pick Index

保证指定的target一定会出现在数组中,可能出现多次,要求等概率给出其中一个合法下标。同时提示考虑内存较小的实际情况。
这就是典型的水塘抽样问题了。

/*
 * @lc app=leetcode id=398 lang=cpp
 *
 * [398] Random Pick Index
 */

// @lc code=start
/*
class Solution {
    unordered_map<int, vector<int>> val2idx;
public:
    Solution(vector<int>& nums) {
        for (int i=0; i<nums.size(); i++) {
            val2idx[nums[i]].push_back(i);
        }
    }

    int pick(int target) {
        assert(!val2idx[target].empty());
        auto&& v = val2idx[target];
        return v[rand() % v.size()];
    }
}; // AC, O(N) space, O(1) time
*/

// The array size can be very large.
// Don't use too much extra space.
class Solution {
    vector<int> arr;
public:
    Solution(vector<int>& nums) : arr(std::move(nums)) {}

    int pick(int target) {
        int cnt = 0;
        int res = -1;
        for (int i=0; i<arr.size(); i++) {
            if (arr[i] != target) continue;
            if ((rand() % ++cnt) == 0) {
                res = i;
            }
        }
        return res;
    } // AC, O(1) space, O(N) time
};
/**
 * Your Solution object will be instantiated and called as such:
 * Solution* obj = new Solution(nums);
 * int param_1 = obj->pick(target);
 */
// @lc code=end

这道题程序检查不严谨,用 O(N) 的内存也可以通过。但是在实际的应用场景中,这是不应该的,就连保存整个数组也不应该允许,只能是 std::move 符合要求,但是这会让不会新标准C++的人很为难 ……

LeetCode 382. Linked List Random Node

这道题换成了从链表中随机返回一个结点的元素,本质上没有变化,只要能顺序访问元素就可以,算法本身不要求随机访问。

/*
 * @lc app=leetcode id=382 lang=cpp
 *
 * [382] Linked List Random Node
 */

// @lc code=start
/**
 * Definition for singly-linked list.
 * struct ListNode {
 *     int val;
 *     ListNode *next;
 *     ListNode() : val(0), next(nullptr) {}
 *     ListNode(int x) : val(x), next(nullptr) {}
 *     ListNode(int x, ListNode *next) : val(x), next(next) {}
 * };
 */
class Solution {
public:
    /** @param head The linked list's head.
        Note that the head is guaranteed to be not null, so it contains at least one node. */
    Solution(ListNode* head) {
        ptr = head;
    }
    
    /** Returns a random node's value. */
    int getRandom() {
        int res = -1;
        size_t n = 0;
        ListNode* p = ptr;
        while (p) {
            if ((rand() % ++n) == 0) {
                res = p->val;
            }
            p = p->next;
        }
        return res;
    }
private:
    ListNode* ptr;
}; // AC, Reservoir Sampling

/**
 * Your Solution object will be instantiated and called as such:
 * Solution* obj = new Solution(head);
 * int param_1 = obj->getRandom();
 */
// @lc code=end

洗牌算法

高端的洗牌,往往只需要最简单的一遍扫描。(_)

大名鼎鼎的 Knuth shuffle,思想与前面的水塘抽样有些类似。不过需要注意两点:

  1. 从后往前扫,这样每个元素至多被交换一次,是确定性的;
  2. 交换时,随机选取的交换对象包括自身原本所在位置,保证每个元素都能取到每个位置。

概率的证明方法类似,也可用归纳法:
第一次交换,任何元素出现在最后一个位置的概率是 1/n;
第二次交换,任何元素出现在倒数第二个位置的概率是 (1-1/n) * 1/(n-1) = 1/n;
...
直到就剩一个元素不用交换。

代码模板

for (int i=arr.size()-1; i>0; i--) {
    std::swap(arr[i], arr[rand() % (i+1)]);
}

LeetCode 384. Shuffle an Array

我们可以考虑手写洗牌算法,也可以直接调用 std::random_shuffle 来洗牌。或者,我们可以自定义生成器来使用 std::shuffle 洗牌,这个写起来更复杂一些。

/*
 * @lc app=leetcode id=384 lang=cpp
 *
 * [384] Shuffle an Array
 */

// @lc code=start
class Solution {
public:
    Solution(vector<int>& nums) : src(std::move(nums)) {}
    
    /** Resets the array to its original configuration and return it. */
    vector<int> reset() {
        return src;
    }
    
    /** Returns a random shuffling of the array. */
    vector<int> shuffle() {
        if (src.empty()) return src;

        vector<int> tmp(src);
        // std::random_shuffle(tmp.begin(), tmp.end()); // AC
        for (int i=tmp.size()-1; i>0; i--) {
            std::swap(tmp[i], tmp[rand() % (i+1)]);
        } // AC
        return tmp;
    }
private:
    const vector<int> src;
};

/**
 * Your Solution object will be instantiated and called as such:
 * Solution* obj = new Solution(nums);
 * vector<int> param_1 = obj->reset();
 * vector<int> param_2 = obj->shuffle();
 */
// @lc code=end
posted @ 2021-03-12 15:53  与MPI做斗争  阅读(158)  评论(0编辑  收藏  举报