蓄水池抽样算法

183 阅读5分钟

在刷leetcode的时候,学习了蓄水池抽样算法。索性将leetcode给出的4道题研究一遍。

image.png

算法思想

算法的基本思想是,在遍历数据流的过程中,保持一个大小固定的蓄水池,初始时蓄水池为空。对于第i个元素,以1/i的概率选择将其放入蓄水池中,并以1-1/i的概率选择保留蓄水池中的元素。

通过这种方式,蓄水池抽样算法能够保证每个元素被选择的概率相等,并且能够有效地处理大规模的数据流。当遍历完整个数据流后,蓄水池中的元素就是随机抽样得到的样本。

蓄水池抽样算法在很多应用中都有广泛的应用,比如随机抽样调查、网络算法等。它的时间复杂度为O(n),其中n为数据流的大小。

以下面的题目为例:

leetcode 382.链表随机节点

给你一个单链表,随机选择链表的一个节点,并返回相应的节点值。每个节点 被选中的概率一样 . 实现 Solution 类: Solution(ListNode head) 使用整数数组初始化对象。 int getRandom() 从链表中随机选择一个节点并返回该节点的值。链表中所有节点被选中的概率相等。  

从所有样本中抽取若干个,要求每个样本被抽到的概率相等。

具体做法为:从前往后处理每个样本,每个样本成为答案的概率为 1/i,其中 i 为样本编号.最终可以确保每个样本成为答案的概率均为 1/n;

证明过程

蓄水池抽样算法的正确性可以通过数学归纳法来证明。

假设已经遍历了前i个元素,且每个元素被选择的概率相等,为1/i。现在考虑第i+1个元素被选择的概率。

如果第i+1个元素不在蓄水池中,那么它被选择的概率为(1-1/(i+1)) * 1/i = 1/(i+1),满足每个元素被选择的概率相等。

如果第i+1个元素在蓄水池中,那么它被选中的概率为1/(i+1) * k/i,其中k为蓄水池中元素的个数。因为每个元素被选择的概率相等,所以第i+1个元素被选中的概率为k/(i+1)。而蓄水池中每个元素被保留的概率为(1-1/i),所以k个元素都被保留的概率为(1-1/(i+1))^k。因此,第i+1个元素被选中且蓄水池中所有元素都被保留的概率为1/(i+1) * k/i * (1-1/(i+1))^k。

现在需要证明,对于任意k,上式等于k/(i+1) * (i/(i+1))^k。这可以通过将(1-1/(i+1))^k展开并化简得到。

因此,第i+1个元素被选择的概率为1/(i+1) * k/i * (1-1/(i+1))^k = k/(i+1) * (i/(i+1))^k,满足每个元素被选择的概率相等。

综上所述,当遍历完整个数据流后,蓄水池中的元素就是随机抽样得到的样本。

class Solution {
public:
    ListNode * h;
    Solution(ListNode* head) {
        h=head;
    }
    
    int getRandom() {
        int n=0,c=-1;
        for(auto p=h;p;p=p->next)
        {
            n++;
            if(rand()%n==0) c=p->val;
        }
        return c;
    }
};

398.随机数索引

分析题目: 用一个哈希表存储相同数字的所有下标 询问的时候直接输出该数字所对应数组里的一个随机值。

class Solution {
public:
    unordered_map<int,vector<int> > mp; //一个数字对应一个下标数组
    Solution(vector<int>& nums) {
        for(int i=0;i<nums.size();i++)
        {
            mp[nums[i]].push_back(i);
        }
    }

    int pick(int target) {
        return mp[target][rand()%mp[target].size()];    //返回目标值的下标数组中的随机值
    }
};

497. 非重叠矩形中的随机点

分析: 用前缀和算出所有矩形面积的累加,在0-最大面积和范围产生随机数,看当前随机数落在哪个矩形的前缀和区间内,

这样选定矩形,确定矩形是时候使用二分快速确定,确定好当前矩形后,在当前矩形区域生成随机数

矩形面积越大,前缀和所占面积越大,被选中的概率越大

class Solution {
public:
    vector<vector<int>> rects;
    int n;
    vector<int> s;


    Solution(vector<vector<int>>& _rects) {
        rects=_rects;
        n=rects.size();
        s.push_back(0);
        for(auto &r:rects)
        {
            int dx=r[2]-r[0]+1,dy=r[3]-r[1]+1;
            s.push_back(s.back()+dx*dy);
        }

    }
// 二分
    vector<int> pick() {
        int k=rand()%s.back()+1;       // 随机数
        int l=1,r=n;
        while(l<r)
        {
            int mid=l+r>>1;
            if(s[mid]>=k) r=mid;
            else l=mid+1;
        }
        auto t=rects[r-1];
        int dx=t[2]-t[0]+1,dy=t[3]-t[1]+1;
        return {rand()%dx+t[0],rand()%dy+t[1]};

    }
};

/**
 * Your Solution object will be instantiated and called as such:
 * Solution* obj = new Solution(rects);
 * vector<int> param_1 = obj->pick();
 */

519. 随机翻转矩阵

我们可以用一个哈希表记录所有1的位置(只用来记录特殊的位置)。为了方便,我们把二维的位置映射到一个int变量,哈希表中只记录这个int变量。

比如,如果位置是第r行第c列,那么我们得到:pos = r * cols + c,其中cols是二维矩阵的列数(即n_cols), pos是将{r, c}映射到一维空间的下标。

随机翻转一个0为1的时候,我们只要找出一个不为1(即哈希表中没记录过的)位置pos = rand() % capacity,将这个位置记录到哈希表中(其中capacity是矩阵的大小,即n_rows * n_cols), 并且求出这个位置对应的二维矩阵的行号: r = pos / cols, 列号: c = pos % cols,将{r, c}作为filp()函数的返回值即可。

reset()方法只需要将哈希表clear即可,表示当前没有记录矩阵中的任何位置的值为1。

using LL= long long;
class Solution {
public:
    unordered_set<LL> hash;
    int row,col,capacity;
    Solution(int m, int n) {
        row=m,col=n;
        capacity=m*n;
    }

    vector<int> flip() {
        LL r,c,pos;
        do
        {
            pos=rand()%capacity;
        }while(hash.count(pos));    // 找到不是1的位置,变为0

        hash.insert(pos);
        r=pos/col;
        c=pos%col;
        return {(int)r,(int)c};
    }

    void reset() {
        hash.clear();
    }
};