np.random.choices的使用
在看莫烦python的RL源码时,他的DDPG记忆库Memory的实现是这样写的:
class Memory(object): def __init__(self, capacity, dims): self.capacity = capacity self.data = np.zeros((capacity, dims)) self.pointer = 0 def store_transition(self, s, a, r, s_): transition = np.hstack((s, a, [r], s_)) index = self.pointer % self.capacity # replace the old memory with new memory self.data[index, :] = transition self.pointer += 1 def sample(self, n): assert self.pointer >= self.capacity, 'Memory has not been fulfilled' indices = np.random.choice(self.capacity, size=n) return self.data[indices, :]
其中sample方法用assert断言pointer >= capacity,也就是说Memory必须满了才能学习。
我在设计一种方案,一开始往记忆库里存比较好的transition(也就是reward比较高的),要是等记忆库填满再学习好像有点浪费,因为会在填满之后很快被差的transition所替代,甚至好的transition不能填满Memory,从而不能有效学习好的经验。
此时就需要关注np.random.choice方法了,看源码解释:
def choice(a, size=None, replace=True, p=None): # real signature unknown; restored from __doc__ """ choice(a, size=None, replace=True, p=None) Generates a random sample from a given 1-D array .. versionadded:: 1.7.0 Parameters ----------- a : 1-D array-like or int If an ndarray, a random sample is generated from its elements. If an int, the random sample is generated as if a were np.arange(a) size : int or tuple of ints, optional Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k`` samples are drawn. Default is None, in which case a single value is returned. replace : boolean, optional Whether the sample is with or without replacement p : 1-D array-like, optional The probabilities associated with each entry in a. If not given the sample assumes a uniform distribution over all entries in a. Returns -------- samples : single item or ndarray The generated random samples
主要第一个参数为ndarray,如果给的是int,np会自动将其通过np.arange(a)转换为ndarray。
此处主要关注的是,a(我们使用int)< size时,np会怎么取?
上代码测试
import numpy as np samples = np.random.choice(3, 5) print(samples)
输出:
[2 1 2 1 1]
所以,是会从np.array(a)重复取,可以推断出,np.random.choice是“有放回地取”(具体我也没看源码,从重复情况来看,至少a<size时是这样的)
然后我分别测试了np.random.choice(5, 5)、np.random.choice(10, 5)等。多试几次会发现samples中确实是会有重复的。:
import numpy as np samples = np.random.choice(10, 5) print(samples) [3 4 3 4 5]