How do you output m distinct numbers, chosen uniformly at random, from 0 to n-1?

Original link: zhuanlan.zhihu.com

What code looks simple but is actually very ingenious? - Zhihu

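This post was prompted by an answer to the question above, and naturally we will write the algorithm in Python. The answerer gave an O(n) algorithm, which is clearly not very effective when n is extremely large, say 10 billion. A commenter also offered another correct approach, but it is not efficient in every case either.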

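First, a few common solutions:

  1. Loop m times, each time drawing a random number from 0 to n-1 (this can obviously produce duplicates, so it is wrong).
  2. Remember every number drawn so far, and redraw whenever a duplicate comes up, until m numbers have been output (when m is close to n, the last few draws degrade badly because most draws hit numbers already taken; average complexity O(m log m)).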
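The second approach in the list above can be sketched as follows. This is a minimal illustration, not code from the original answer; the function name `sample_by_rejection` and the `rng` parameter are hypothetical names chosen for this example.

```python
import random

def sample_by_rejection(n, m, rng=random):
    """Draw m distinct integers from range(n), redrawing on duplicates."""
    if not 0 <= m <= n:
        raise ValueError("m must satisfy 0 <= m <= n")
    seen = set()      # values already emitted, for O(1) duplicate checks
    result = []       # kept separately to preserve selection order
    while len(result) < m:
        x = rng.randrange(n)
        if x not in seen:       # a repeated value is discarded and redrawn
            seen.add(x)
            result.append(x)
    return result

print(sample_by_rejection(100, 5))
```

Memory use is proportional to m, but as m approaches n almost every draw is a repeat, which is the degradation described in the list.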

If you know the Python standard library, you will immediately realize that this problem needs nothing more than

from random import sample

sample(xrange(n), m)

(In Python 3, replace xrange with range.)

and that is all. Of course we cannot just call that the answer, but we can look at how the standard library implements it:

def sample(self, population, k):
        """Chooses k unique random elements from a population sequence or set.
        Returns a new list containing elements from the population while
        leaving the original population unchanged.  The resulting list is
        in selection order so that all sub-slices will also be valid random
        samples.  This allows raffle winners (the sample) to be partitioned
        into grand prize and second place winners (the subslices).
        Members of the population need not be hashable or unique.  If the
        population contains repeats, then each occurrence is a possible
        selection in the sample.
        To choose a sample in a range of integers, use range as an argument.
        This is especially fast and space efficient for sampling from a
        large population:   sample(range(10000000), 60)
        """

        # Sampling without replacement entails tracking either potential
        # selections (the pool) in a list or previous selections in a set.

        # When the number of selections is small compared to the
        # population, then tracking selections is efficient, requiring
        # only a small set and an occasional reselection.  For
        # a larger number of selections, the pool tracking method is
        # preferred since the list takes less space than the
        # set and it doesn't suffer from frequent reselections.

        if isinstance(population, _Set):
            population = tuple(population)
        if not isinstance(population, _Sequence):
            raise TypeError("Population must be a sequence or set.  For dicts, use list(d).")
        randbelow = self._randbelow
        n = len(population)
        if not 0 <= k <= n:
            raise ValueError("Sample larger than population or is negative")
        result = [None] * k
        setsize = 21        # size of a small set minus size of an empty list
        if k > 5:
            setsize += 4 ** _ceil(_log(k * 3, 4)) # table size for big sets
        if n <= setsize:
            # An n-length list is smaller than a k-length set
            pool = list(population)
            for i in range(k):         # invariant:  non-selected at [0,n-i)
                j = randbelow(n-i)
                result[i] = pool[j]
                pool[j] = pool[n-i-1]   # move non-selected item into vacancy
        else:
            selected = set()
            selected_add = selected.add
            for i in range(k):
                j = randbelow(n)
                while j in selected:
                    j = randbelow(n)
                selected_add(j)
                result[i] = population[j]
        return result

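The standard library's sample has two parts. When k is much smaller than n and n itself is fairly large, it uses the second method described earlier. This is a memory consideration: keeping just a few elements in a set is economical. Otherwise it uses a different method. To explain that method, let us first look at the classic shuffle algorithm, which randomly permutes the n elements of a list. In Python this is shuffle: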
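To make the branch choice concrete, the threshold computation from the quoted source can be pulled out into a small helper. This is only a sketch: `uses_pool_branch` is a hypothetical name, and the constants are copied from the listing above rather than from any particular Python release.

```python
from math import ceil, log

def uses_pool_branch(n, k):
    """Return True if the quoted sample() would copy the population into a list."""
    setsize = 21                              # same constant as in the listing
    if k > 5:
        setsize += 4 ** ceil(log(k * 3, 4))   # estimated table size for big sets
    return n <= setsize

print(uses_pool_branch(10, 5))            # True: small population, list pool
print(uses_pool_branch(10_000_000, 60))   # False: set of previous selections
```

For k = 60 the threshold works out to 21 + 4**4 = 277, so any population larger than 277 elements takes the set-based branch.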

def shuffle(self, x, random=None):
        """Shuffle list x in place, and return None.
        Optional argument random is a 0-argument function returning a
        random float in [0.0, 1.0); if it is the default None, the
        standard random.random will be used.
        """

        if random is None:
            randbelow = self._randbelow
            for i in reversed(range(1, len(x))):
                # pick an element in x[:i+1] with which to exchange x[i]
                j = randbelow(i+1)
                x[i], x[j] = x[j], x[i]
        else:
            _int = int
            for i in reversed(range(1, len(x))):
                # pick an element in x[:i+1] with which to exchange x[i]
                j = _int(random() * (i+1))
                x[i], x[j] = x[j], x[i]

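For the actual implementation, only the first branch matters. The principle is very simple: work from the end of the list toward the front, and at each step pick one element at random from the part that has not been fixed yet (positions 0 through i) and swap it with the last element of that part, x[i]. It is easy to prove that after the swaps every element is uniformly distributed over every position.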
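The uniformity claim can be checked with a short telescoping product. As a sketch, fix one element and a target position n-k (0-indexed, with 1 ≤ k ≤ n): the element must not be picked in the first k-1 draws, and must then be picked from the n-k+1 candidates that remain.

```latex
\[
P(\text{element ends at position } n-k)
= \left( \prod_{t=1}^{k-1} \frac{n-t}{n-t+1} \right) \cdot \frac{1}{n-k+1}
= \frac{n-k+1}{n} \cdot \frac{1}{n-k+1}
= \frac{1}{n}.
\]
```

The result does not depend on k, so every element is equally likely to end up at every position.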

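Clearly, then, if we run only m of these steps, we get exactly the m distinct numbers we want. This is what the first branch of Python's sample does, with the small difference that it overwrites the chosen slot with the last non-selected item instead of swapping.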
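A shuffle that stops after m steps might look like the sketch below. The name `sample_by_partial_shuffle` and the `rng` parameter are hypothetical; unlike the dict-based idea discussed later, this version still allocates a list of n integers.

```python
import random

def sample_by_partial_shuffle(n, m, rng=random):
    """Run only the first m steps of the shuffle and return the fixed tail."""
    if not 0 <= m <= n:
        raise ValueError("m must satisfy 0 <= m <= n")
    pool = list(range(n))            # O(n) memory: the whole population
    for i in range(m):
        last = n - i - 1             # last position of the undecided part
        j = rng.randrange(n - i)     # uniform pick from positions [0, last]
        pool[j], pool[last] = pool[last], pool[j]
    return pool[n - m:]              # the m positions fixed so far

print(sample_by_partial_shuffle(20, 5))
```

The time spent in the loop is O(m), but building the list costs O(n) time and memory, which is what makes this unattractive when n is huge and m is small.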


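Python's implementation is two algorithms stitched together. Is there a single algorithm that handles both cases, with both time and space complexity O(m)? Let us borrow an idea from PHP, the best language in the world: its array implementation is basically a hashmap (which is why it is so slow), so its indices do not have to be contiguous. This suggests that, building on the shuffle algorithm, we can replace the array with a hashmap (that is, a dict), so that when n is very large and m is small, only a small amount of memory is needed to track the selection. That leads to the following algorithm: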

from random import randrange
from itertools import islice

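def sample_generator(n):
    # pool[i] is the value at index i; a missing key means the value is i itself
    pool = {}
    for i in range(n):          # invariant: non-selected at [0,n-i); use xrange on Python 2
        j = randrange(n-i)
        result = pool.get(j, j)
        pool[j] = pool.get(n - i - 1, n - i - 1)   # move last non-selected value into vacancy
        yield result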

def sample(n, m):
    return list(islice(sample_generator(n), 0, m))

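We use a dict named pool in place of the list; when an index is not present in the dict, its value defaults to the index itself. Note that we use a generator instead of a fixed loop of m iterations, and apply islice when exactly m iterations are needed. This lets us draw random numbers on demand, as many as we need, which is convenient.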
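The on-demand behaviour can be tried out with a very large n. The block below is a self-contained Python 3 restatement of the generator (assuming `range` in place of `xrange`), followed by two separate draws from the same generator; the variable names `gen`, `first` and `more` are illustrative only.

```python
from itertools import islice
from random import randrange

def sample_generator(n):
    pool = {}                            # sparse stand-in for list(range(n))
    for i in range(n):                   # range is lazy in Python 3
        last = n - i - 1                 # last index of the undecided part
        j = randrange(n - i)             # uniform pick from [0, last]
        result = pool.get(j, j)          # missing key: value equals the index
        pool[j] = pool.get(last, last)   # move the last value into the vacancy
        yield result

gen = sample_generator(10**10)           # no list of 10 billion items is built
first = list(islice(gen, 3))             # take three numbers now
more = list(islice(gen, 2))              # take two more later, still distinct
print(first, more)
```

Because the generator keeps its dict between calls, the later draws never repeat the earlier ones, and the dict holds at most one entry per number drawn.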

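The time and space complexity of this algorithm are both O(m), regardless of n. However, because a list is somewhat more efficient than a dict, when n is small it may use more time and memory than the list version.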