random.choices 函数时间复杂度
random.choices 函数
python 官方标准库 random 中,有个函数 random.choices(population, weights=None, *, cum_weights=None, k=1)
,比起常用的 random.choice(seq)
,这个函数可以指定概率权重和选择次数。
因为刷题的时候用到了这个函数,题目又对时间复杂度有限制,我就很好奇,然后来分析一下这个函数的时间复杂度。
源码
def choices(self, population, weights=None, *, cum_weights=None, k=1):
"""Return a k sized list of population elements chosen with replacement.
If the relative weights or cumulative weights are not specified,
the selections are made with equal probability.
"""
random = self.random
n = len(population)
if cum_weights is None:
if weights is None:
floor = _floor
n += 0.0 # convert to float for a small speed improvement
return [population[floor(random() * n)] for i in _repeat(None, k)]
cum_weights = list(_accumulate(weights))
elif weights is not None:
raise TypeError('Cannot specify both weights and cumulative weights')
if len(cum_weights) != n:
raise ValueError('The number of weights does not match the population')
total = cum_weights[-1] + 0.0 # convert to float
if total <= 0.0:
raise ValueError('Total of weights must be greater than zero')
bisect = _bisect
hi = n - 1
return [population[bisect(cum_weights, random() * total, 0, hi)]
for i in _repeat(None, k)]
参数说明
population
: 输入的待选取序列weights
: 权重序列cum_weights
: 累加的权重序列,相当于weights
的前缀和数组k
: 选取的次数,该函数会返回一个长度为k
的列表
功能说明
参考官方文档可知,这个函数通过权重随机选取数字,比如 choices([1, 2], weights=[3, 2])
,相当于使用 choice([1, 1, 1, 2, 2])
,也可以写成 choices([1, 2], cum_weights=[3, 5])
假设给出了权重(weights
)但是没有累加权重(cum_weights
):
- 函数内部会把权重累加
cum_weights = list(_accumulate(weights))
; - 使用
random()
函数输出一个[0.0, 1.0)
区间的数,乘上所有权重的累加和,作为生成的随机数。权重的累加和也是cum_weights
数组最后一个元素值; - 用二分查找 (标准库函数:bisect) 在累加序列
cum_weights
中找到随机数的位置,输出该位置的数据。
时间复杂度分析
函数共有 2 个出口:
-
weights
和cum_weights
均为None
的情况:return [population[floor(random() * n)] for i in _repeat(None, k)]
时间复杂度:O(k) ,因为
k
为常数,所以也可以认为时间复杂度为 O(1)这种情况和直接使用
choice
没有差别,所以我就不考虑在最终结果里了。 -
weights
不为None
的情况:return [population[bisect(cum_weights, random() * total, 0, hi)] for i in _repeat(None, k)]
时间复杂度:O(klog(n)),因为
k
为常数,所以也可以认为时间复杂度为 O(log(n)) (注:log(n) 来自二分查找)- 如果
cum_weights
为None
,还需要执行cum_weights = list(_accumulate(weights))
,_accumulate
类似于itertools.accumulate()
,时间复杂度:O(n),与上面的 O(log(n)) 叠加,总时间复杂度为:O(n)
- 如果
所以结论在于用户有没有给出累加权重,也就是 cum_weights
数组:
- 如果给出
cum_weights
:O(log(n)) ,精确一点就是 O(klog(n)) ,这个k
就是那个参数k
,是个常数。 - 如果没有给出:O(n)
所以呢,如果数据规模特别大,还是要谨慎使用这个函数的,尤其是没有提供 cum_weights
参数的时候。