关于乱序(shuffle)与随机采样(sample)的一点探究
最近一个月的时间,基本上都在加班加点的写业务,在写代码的时候,也遇到了一个有趣的问题,值得记录一下。
简单来说,需求是从一个字典(python dict)中随机选出K个满足条件的key。代码如下(python2.7):
1 def choose_items(item_dict, K, filter): 2 '''item_dict = {id:info} ''' 3 candidate_ids = [id for id in item_dict if filter(item_dict[id])] 4 if len(candidate_ids) <= K: 5 return set(candidate_ids) 6 else: 7 return set(random.sample(candidate_ids, K))
代码逻辑很简单,也能正常工作。但我知道这个函数调用的频率会很高,len(item_dict)也会比较大,那么这段代码会不会有效率问题呢。当然,一切都要基于profile,如果确实有问题,那么就需要优化。但首先,我想搞明白的是,我使用了random.sample这个函数,这个函数的时间复杂度如何呢。另外,我也经常会使用random.shuffle函数,也就想一并搞清楚。
本文记录对shuffle,sample两个算法的思考,参考的是python2.7.3中的random模块。当然,这两个算法与语言并不相关。另外,本人对算法研究甚少,认识错误之处还请大家不吝赐教。
本文地址:http://www.cnblogs.com/xybaby/p/8280936.html
Shuffle
shuffle的意思就是让序列乱序,本质上就是让序列里面的每一个元素等概率的重新分布在序列的任何位置。在使用MP3听歌(是不是暴露的年龄)的时候,就有两个功能:shuffle,random,二者的区别在于,前者打乱播放顺序,保证所有的歌曲都会播放一遍;而后者每次随机选择一首。
Python里面random.shuffle源码如下:
1 def shuffle(self, x, random=None, int=int): 2 """x, random=random.random -> shuffle list x in place; return None. 3 4 Optional arg random is a 0-argument function returning a random 5 float in [0.0, 1.0); by default, the standard random.random. 6 """ 7 8 if random is None: 9 random = self.random 10 for i in reversed(xrange(1, len(x))): 11 # pick an element in x[:i+1] with which to exchange x[i] 12 j = int(random() * (i+1)) 13 x[i], x[j] = x[j], x[i]
核心的代码就3行,其实就是非常经典的Fisher–Yates shuffle算法的实现,Fisher–Yates shuffle算法伪码如下:
-- To shuffle an array a of n elements (indices 0..n-1): for i from n−1 downto 1 do j ← random integer such that 0 ≤ j ≤ i exchange a[j] and a[i]
第一步 即从0到N-1个元素中随机选择一个与第N-1个替换
第二步 从0到N-2个元素中随机选择一个与第N-2个替换
第k步 从0到N-k个元素中随机选择一个与第N-K个替换
要证明算法的正确性也很简单,即任何一个元素shuffle之后出现在任意位置的概率都是1/N。任意一个元素,放在第N-1个位置的概率是1/N, 放在pos N-2的位置是 (N-1)/N * 1 / (N-1) = 1/N 。需要注意的是,一个元素一旦被交换到了序列的尾部,那么就不会再被选中,这也是算法一目了然的原因。
上面的实现是从后到前的,当然也可以从前到后,即先从0到N-1个元素中随机选择一个与第0个交换,然后从1到N-1个元素中随机选择一个与第1个交换 。。。只不过写代码的时候会稍微麻烦一点点,wiki上也有相应的伪码。
但是,我也看到网上有这么一种实现:
1 void get_rand_number(int array[], int length) 2 { 3 int index; 4 int value; 5 int median; 6 7 if(NULL == array || 0 == length) 8 return ; 9 10 /* 每次发牌的时候任意分配待交换的数据 */ 11 for(index = 0; index < length; index ++){ 12 value = rand() % length; 13 14 median = array[index]; 15 array[index] = array[value]; 16 array[value] = median; 17 } 18 }
与Fisher–Yates shuffle算法的区别在于,上面的算法每次都是从整个序列中选择一个元素作为被交换的元素,即先从整个序列选择一个元素与第0个元素交换,然后再从整个序列选择一个元素与第1个元素交换.。。。这个直觉就有点问题,比如一个元素(X)第一步就放到了第0个位置,但是之后有可能被交换到其他位置,以后X就再也不会回到第0个元素,当然,X也可能再第二步 第三步被交换到第0个位置。
但要证明该算法有问题似乎不是这么容易,那么首先用事实(数据)说话,于是我用python重写了上述代码,并做了测试,代码如下
1 import random 2 def myshuffle(lst): 3 length = len(lst) 4 for idx in xrange(length): 5 t_idx = random.randint(0, length-1) 6 lst[idx], lst[t_idx] = lst[t_idx], lst[idx] 7 if __name__ == '__main__': 8 random.seed() 9 10 pre_lst = ['a', 'b', 'c'] 11 count = dict((e, {}) for e in pre_lst) 12 TRY = 1000000 13 14 for i in xrange(TRY): 15 lst = pre_lst[:] 16 myshuffle(lst) 17 for alpha in pre_lst: 18 idx = lst.index(alpha) 19 count[alpha][idx] = count[alpha].get(idx, 0) + 1 20 21 for alpha, alpha_count in sorted(count.iteritems(), key=lambda e: e[0]): 22 result_lst = [] 23 for k, v in sorted(alpha_count.iteritems(), key=lambda e: e[0]): 24 result_lst.append(round(v * 1.0 / TRY, 3)) 25 print alpha, result_lst
运算的结果是:
('a', [0.333, 0.334, 0.333])('b', [0.371, 0.296, 0.333])('c', [0.296, 0.37, 0.334])
如果将pre-list改成 pre_list = ['a', 'b', 'c', 'd', 'e'],那么输出结果是:
('a', [0.2, 0.2, 0.2, 0.2, 0.199])('b', [0.242, 0.18, 0.186, 0.191, 0.2])('c', [0.209, 0.23, 0.175, 0.186, 0.2])('d', [0.184, 0.205, 0.23, 0.18, 0.2])('e', [0.164, 0.184, 0.209, 0.242, 0.2])
这里稍微解释一下输出,每一行是字母在shuffle之后,出现在每一个位置的概率。比如元素‘e',在pre_list的位置是4(从0开始),shuffle之后,出现在第0个位置的统计概率为0.164,出现在第1个位置的统计概率是0.184,显然不是等概率的。
假设P[i][j]是原来序列种第i个元素shuffle之后移动到第j个位置的概率,那么这个公式怎么推导昵?我尝试过,不过没有推出来。
在stackoverflow上,我提问了这个问题,并没有得到直接的答案,不过有一个回答很有意思,指出了从理论上这个shuffle算法就不可能是正确的
This algorithm has
n^n
different ways to go through the loop (n
iterations picking one ofn
indexes randomly), each equally likely way through the loop producing one ofn!
possible permutations. Butn^n
is almost never evenly divisible byn!
. Therefore, this algorithm cannot produce an even distribution.
就是说,myshuffle由N^N种可能,但按照排队组合,N个元素由N的阶乘种排列方式。N^N不能整除N的阶乘,所以不可能是等概率的。
欢迎大家帮忙推倒这个公式,我自己只能推出P[N-1][0], P[N-2][0],真的头大。
Sample
Python中random.sample的document是这样的:
random.sample(population, k)
Return a k length list of unique elements chosen from the population sequence. Used for random sampling without replacement.
上面的document并不完整,不过也可以看出,是从序列(sequence)中随机选择k个元素,返回的是一个新的list,原来的序列不受影响。
但是从document中看不出时间复杂度问题。所以还是得看源码:
1 def sample(self, population, k): 2 """Chooses k unique random elements from a population sequence. 3 4 Returns a new list containing elements from the population while 5 leaving the original population unchanged. The resulting list is 6 in selection order so that all sub-slices will also be valid random 7 samples. This allows raffle winners (the sample) to be partitioned 8 into grand prize and second place winners (the subslices). 9 10 Members of the population need not be hashable or unique. If the 11 population contains repeats, then each occurrence is a possible 12 selection in the sample. 13 14 To choose a sample in a range of integers, use xrange as an argument. 15 This is especially fast and space efficient for sampling from a 16 large population: sample(xrange(10000000), 60) 17 """ 18 19 # Sampling without replacement entails tracking either potential 20 # selections (the pool) in a list or previous selections in a set. 21 22 # When the number of selections is small compared to the 23 # population, then tracking selections is efficient, requiring 24 # only a small set and an occasional reselection. For 25 # a larger number of selections, the pool tracking method is 26 # preferred since the list takes less space than the 27 # set and it doesn't suffer from frequent reselections. 28 29 n = len(population) 30 if not 0 <= k <= n: 31 raise ValueError("sample larger than population") 32 random = self.random 33 _int = int 34 result = [None] * k 35 setsize = 21 # size of a small set minus size of an empty list 36 if k > 5: 37 setsize += 4 ** _ceil(_log(k * 3, 4)) # table size for big sets 38 if n <= setsize or hasattr(population, "keys"): 39 # An n-length list is smaller than a k-length set, or this is a 40 # mapping type so the other algorithm wouldn't work. 41 pool = list(population) 42 for i in xrange(k): # invariant: non-selected at [0,n-i) 43 j = _int(random() * (n-i)) 44 result[i] = pool[j] 45 pool[j] = pool[n-i-1] # move non-selected item into vacancy 46 else: 47 try: 48 selected = set() 49 selected_add = selected.add 50 for i in xrange(k): 51 j = _int(random() * n) 52 while j in selected: 53 j = _int(random() * n) 54 selected_add(j) 55 result[i] = population[j] 56 except (TypeError, KeyError): # handle (at least) sets 57 if isinstance(population, list): 58 raise 59 return self.sample(tuple(population), k) 60 return result
咋眼一看,不同的情况下有两种方案(对应38行的if、46行的else),一种方案类似shuffle,复杂度是O(K);而另一种方案,看代码的话,复杂度是O(NlogN) (后面会说明,事实并非如此)
我当时就惊呆了,这个时间复杂度不能接受吧,在Time complexity of random.sample中,也有网友说是O(NlogN)。这个我是不能接受的,这个是官方模块,怎么可能这么不给力,那我自然想看明白这段代码。
Sample的各种实现
在这之前,不妨考虑一下,如果要自己实现这个sample函数,那么有哪些方法呢。
我们首先放宽sample的定义,就是从有N个元素的序列中随机取出K个元素,不考虑是否影响原序列
第一种,随机抽取且不放回
跟抽牌一样,随机从序列中取出一个元素,同时从原序列中删除,那么不难验证每个元素被取出的概率都是K/N(N是序列长度),满足Sample需求。
若不考虑元素从列表中删除的代价,那么时间复杂度是O(K)。但问题也很明显,就是会修改原序列
第二种,随机抽取且放回
除了记录所有被选择的元素,还需要维护被选择的元素在序列中的位置(selected_pos_set)。随机从序列中取出一个元素,如果抽取到的元素的位置在selected_pos_set中,那么重新抽取;否则将新元素的位置放到selected_pos_set中。
不难发现,这个就是python random.sample代码中第二种实现。
这个算法的好处在于,不影响原序列。
那么时间复杂度呢?在抽取第i个元素的时候,抽取到重复位置元素的概率是(i - 1)/N,那么平均抽取次数就是N/(N - i +1)。那么抽取K个元素的平均抽取测试就是,sum(N/(N - i +1) ), 1 <= i <= K; 等于N(logN - log(N-K+1)) 。当K等于N时,也就是NlogN
第三种,先shuffle整个序列,然后取前K个元素
算法的正确性很容易验证,时间复杂度是O(N),而且原序列会被修改(乱序也算做修改)
第四种,部分shuffle,得到K个元素就返回
如果了解shuffle算法,那么算法理解还是很容易的。random.sample中第一种方案也是这样的算法。
单独实现这个算法的话就是这个样子的:
1 def sample_with_shuffle(self, population, k): 2 n = len(population) 3 result = [None] * k 4 for i in xrange(k): # invariant: non-selected at [0,n-i) 5 j = int(random.random() * (n-i)) 6 result[i] = population[j] 7 population[j] = population[n-i-1] # move non-selected item into vacancy 8 return result
时间复杂度是O(K),但缺点就是原序列会被改变。
第五种,水塘抽样算法
水塘抽样算法(Reservoir_sampling)解决的是 样本总体很大,无法一次性放进内存;或者是在数据流上的随机采样问题。即不管有多少个元素,被选中的K个元素都是等概率的。算法很巧妙,也是非常经典的面试题。
算法伪码是这样的:
1 ReservoirSample(S[1..n], R[1..k]) 2 // fill the reservoir array 3 for i = 1 to k 4 R[i] := S[i] 5 6 // replace elements with gradually decreasing probability 7 for i = k+1 to n 8 j := random(1, i) // important: inclusive range 9 if j <= k 10 R[j] := S[i]
算法的时间复杂度是O(N),且不会影响原序列。
回到random.sample
通过上面的思考可见,最低复杂度是O(K),但需要改变原序列。如果不改变原序列,时间复杂度最低为O(N)。
但是如果重新拷贝一份原序列,那么是可以使用部分shuffle,但拷贝操作本身,需要时间与额外的空间。
其实python random.sample这个函数的注释说明了实现的原因:
# Sampling without replacement entails tracking either potential
# selections (the pool) in a list or previous selections in a set.# When the number of selections is small compared to the
# population, then tracking selections is efficient, requiring
# only a small set and an occasional reselection. For
# a larger number of selections, the pool tracking method is
# preferred since the list takes less space than the
# set and it doesn't suffer from frequent reselections.
当K相对N较小时,那么使用python set记录已选择的元素位置,重试的概率也会较小。当K较大时,就用list拷贝原序列。显然,这是一个 hybrid algorithm实现,不管输入如何,都能有较好的性能。
因此,算法的实现主要考虑的是额外使用的内存,如果list拷贝原序列内存占用少,那么用部分shuffle;如果set占用内存少,那么使用记录已选项的办法。
因此核心的问题,就是对使用的内存的判断。看代码,有几个magic num,其实就其中在这三行:
涉及到两个magic num:21 与 5, 还有一个公式。
magic 21
代码中是有注释的,即21是small 减去 empty list的大小。但是,我并没有搞懂为啥是21.
对于64位的python:
>>> import sys
>>> sys.getsizeof(set())
232
>>> sys.getsizeof([])
72
可以看到,二者相差160。另外,在Linux下,这个size应该都是8的倍数,所以至今不知道21是咋来的
magic 5
这个比较好理解,新创建的set默认会分配一个长度为8的数组。
当set中的元素超过了容量的2/3,那么会开辟新的存储空间,因此,所以当set中的元素小于等于5个时,使用默认的小数组,无需额外的空间
公式: 4 ** _ceil(_log(k * 3, 4))
log是取对数,base是4, ceil是向上取整。所以上面的公式,其范围在[3K, 12K]之间。
为什么有这么公式,应该是来自这段代码(setobject.c::set_add_entry)
set的容量临界值是3/2 * len(set), 超过了这个值,那么会分配四倍的空间。那么set分配的容量(N)与元素数目(K)的比例大致是 [3/2, 12/2]。
由于set中,一个setentry包含16个字节(8个字节的元素本身,以及8个字节的hash值),而list中一个元素只占用8个字节。所以当对比set与list的内存消耗是,上述的比例乘以了2.
random.sample有没有问题
当序列的长度小于K个元素所占用的空间时,使用的是部分shuffle的算法,当然,为了避免修改原序列,做了一个list拷贝。
否则使用随机抽取且放回的算法,需要注意的是,在这个时候, N的范围是[3K, 12K],即此时K是不可能趋近于N的,按照之前推导的公式 N(logN - log(N-K+1)), 时间复杂度均为O(K)。
因此,不管序列的长度与K的大小关系如何,时间复杂度都是O(K),且保证使用的内存最少。
这里也吐槽一下,在这个函数的docsting里面,提到的是对sequence进行随机采样,没有提到支持dict set,按照对ABC的理解,collections.Sequence 是不包含dict,set的。但事实上这个函数又是支持这两个类型的参数的。更令人费解的是,对set类型的参数,是通过捕获异常之后转换成tuple来来支持的。
random.sample还有这么一个特性:
The resulting list is in selection order so that all sub-slices will also be valid random samples. This allows raffle winners (the sample) to be partitioned into grand prize and second place winners (the subslices).
就是说,对于随机采样的结果,其sub slice也是符合随机采样结果的,即sample(population, K)[0, M] === sample(population, M), M<=K。在上面提到的各种sample方法中,水塘抽样算法是不满足这个特性的。
总结
本文记录了在使用python random模块时的一些思考与测试。搞清楚了random.shuffle, random.sample两个函数的实现原理与时间复杂度。
不过,还有两个没有思考清楚的问题
第一:myshuffle的实现中, p[i][j]的公式推导
第二:random.sample中,21 这个magic num是怎么来的
如果园友知道答案,还望不吝赐教