水塘抽样算法原理与Demo实现

1,071 阅读4分钟

问题引入

考虑这样一个问题:现在有一个很大的数据集,总数为n,但这个n未知且可能很大,以致于没法一次全部加载到内存,如何从数据集中随机的选出一个数据项?

初步解析

当总数n为一个较小规模的值时,我们可以把数据集加载到内存的一个数组里,然后使用random(n)得到一个随机数,根据这个随机数取具体的第几个数据项,这样可以实现随机地获得一个数据项。

但当n为一个很大的数值没有办法一次性把数据加载到内存时,就不能使用上述方法了。可以想到的一种方式是,首先遍历数据集,统计数据项的总数n,然后random(n)得到一个随机数m,然后再遍历一次数据集,返回第m个数据项。这种方式遍历了两遍数据集,还是不够好,有什么办法可以只遍历一次数据集就可以实现随机返回一个数据项吗?

引入算法

要说的水塘抽样算法就是做这个事情的。我们把数据集抽象成一个长度很长且未知的链表,然后我们可以新建一个链表节点ansNode存储待返回随机节点的备选节点,然后让我开始遍历链表。

水塘抽样算法是这样的一个步骤:

第0步,未开始遍历链表,备选随机节点ansNode为空;

第1步,开始遍历第一个节点node1,此时只有一个节点,所以备选节点为node1;

第2步,开始遍历第二个节点node2,此时要更新备选节点,在当前备选节点ansNode和当前遍历节点node2中选取一个,1/2的概率让node2成为新的ansNode,1/2的概率保留原ansNode;

第三步,开始遍历第三个节点node3,1/3的概率让node3成为新的ansNode,2/3的概率保留原ansNode;

……

第i步,开始遍历第i个节点node(i),1i\frac{1}{i}的概率让node(i)成为新的ansNode,i1i\frac{i-1}{i}的概率保留原ansNode;

……

第n步,开始遍历第n个节点noden,1n\frac{1}{n}的概率让noden成为新的ansNode,n1n\frac{n-1}{n}的概率保留原ansNode;

返回ansNode。

证明其正确性

ii可以为1~n中的任意一个数,

访问到第ii个节点时,其作为备选节点ansNode留下来的概率为1i\frac{1}{i}

访问到第i+1i+1个节点时,第ii个节点作为备选节点ansNode留下来的概率为1iii+1=1i+1\frac {1}{i} * \frac{i}{i+ 1}=\frac{1}{i+1}

访问到第i+2i+2个节点时,第ii个节点作为备选节点ansNode留下来的概率为1i+1i+1i+2=1i+2\frac{1}{i+1} * \frac{i+1}{i+2}=\frac{1}{i+2}

……

访问到第n1n-1个节点时,第ii个节点作为备选节点ansNode留下来的概率为1n1\frac{1}{n-1}

访问到第nn个节点时,第ii个节点作为备选节点ansNode留下来的概率为1n\frac{1}{n}

问题扩展

上述问题是在nn个数据项里随机地选取1个数据项,如果在nn个数据项里随机选取kk个数据项,怎么解决呢?

同样地,首先初始化一个长度为kk备选数组

Node[] ansNodes = new Node[k];

第0步,将前k个数据项作为备选项填充到数组中;

第1步,访问第k+1k+1个数据项,kk+1\frac{k}{k+1}的概率留下该数据项,替换当前备选数据项中任意一个的概率都为kk+1k=1k+1{\frac{k}{k+1}\above{2pt} {k}}=\frac{1}{k+1},即对前k个数据项每个数据项留下来的概率为11k+1=kk+11-\frac{1}{k+1}=\frac{k}{k+1},所以对于前k+1k+1个数据项中每个而言作为备选项留下来的概率都为kk+1\frac{k}{k+1}

第2步,访问第k+2个数据项,kk+2\frac{k}{k+2}的的概率留下该数据项,替换当前k个备选数据项中任意一个的概率都为kk+2k=1k+2{\frac{k}{k+2}\above{2pt} {k}}=\frac{1}{k+2},即对前k+1个数据项每个数据项留下来的概率为(11k+2)kk+1=kk+2(1-\frac{1}{k+2})*\frac{k}{k+1}=\frac{k}{k+2},所以对前k+2个数据项中每个而言作为备选项留下来的概率都为kk+2\frac{k}{k+2}

……

依次类推,访问第nn个数据项,可以保持nn个数据项都有kn\frac{k}{n}的概率被返回。

Demo实现

//随机返回一个节点
public ListNode getRandom(ListNode head){
    ListNode ansNode = head;
    int pos = 1;
    ListNode cur = head.next;
    while (cur != null) {
        pos++;
        int randomNum = random.nextInt(pos);
        if (randomNum == pos - 1) {
            ansNode = cur;
        }
        cur = cur.next;
    }
    return ansNode;
}



//随机返回k个节点
public ListNode[] getKRandom(ListNode head, int k) {
    ListNode[] ansNodes = new ListNode[k];
    ListNode cur = head;
    int pos = 1;
    for (; pos <= k; ++pos) {
        ansNodes[pos - 1] = cur;
        cur = cur.next;
    }
    while (cur != null) {
        int randomNum = random.nextInt(pos);
        if (randomNum < k) {
            ansNodes[randomNum] = cur;
        }
        cur = cur.next;
        pos++;
    }
    return ansNodes;
}