sample_by_num

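def sample_by_num(data_dict: dict, num: int) -> dict:
    """
    Take the first `num` trajectories from data_dict.

    data_dict["index"] is assumed to hold the start offset of each
    trajectory; every other key holds per-transition data.
    """
    index = data_dict["index"]
    # Trajectories 0..num-1 end where trajectory `num` starts; if `num`
    # covers every trajectory, keep all transitions (slice end None).
    end = int(index[num]) if num < len(index) else None

    samples = {}
    for k, v in data_dict.items():
        if k == "index":
            samples[k] = v[:num]
        else:
            samples[k] = v[:end]

    return samples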
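To make the slicing concrete, here is a small self-contained run on a toy dataset. The layout is an assumption for illustration: `index` holds the start offset of each trajectory (three trajectories of lengths 2, 3 and 1), and `obs` and `reward` are hypothetical per-transition keys. Plain lists are used so the sketch has no dependencies; the same slices work on NumPy arrays. Asking for `num=2` keeps the first two start offsets and the first `index[2] = 5` transitions.

```python
def sample_by_num(data_dict: dict, num: int) -> dict:
    """Keep the first `num` trajectories (compact form, with a guard for num == #trajs)."""
    index = data_dict["index"]  # assumed: start offset of each trajectory
    end = int(index[num]) if num < len(index) else None
    return {k: (v[:num] if k == "index" else v[:end]) for k, v in data_dict.items()}


# Hypothetical toy dataset: 3 trajectories with lengths 2, 3 and 1.
data = {
    "index": [0, 2, 5],
    "obs": [[0], [1], [2], [3], [4], [5]],         # one row per transition
    "reward": [1.0, 1.0, 2.0, 2.0, 2.0, 3.0],
}

sub = sample_by_num(data, 2)
print(sub["index"])     # [0, 2]
print(sub["reward"])    # [1.0, 1.0, 2.0, 2.0, 2.0]
print(len(sub["obs"]))  # 5

full = sample_by_num(data, 3)  # every trajectory: the guard avoids index[3] IndexError
print(len(full["obs"]))  # 6
```

Without the `num < len(index)` guard, requesting all trajectories would read one past the end of `index` and raise an `IndexError`.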

From: offlinerl/neorl
posted @ 2022-04-25 20:22  呦呦南山