sample_by_num
def sample_by_num(data_dict: dict, num: int) -> dict:
    """
    Sample the first ``num`` trajectories from ``data_dict``.

    ``data_dict["index"]`` is treated as a sequence of per-trajectory start
    offsets into every other value of the dict. This is implied by slicing
    the data up to ``index[num]`` while keeping ``index[0:num]``. Presumably
    ``index[0] == 0``; confirm against the dataset loader.

    Args:
        data_dict: Mapping of field name to a sliceable sequence (list,
            array, ...). Must contain the key ``"index"`` unless it is empty.
        num: Number of leading trajectories to keep. If ``num`` is at least
            the number of trajectories, every trajectory (all rows) is kept.

    Returns:
        A new dict with the same keys. ``"index"`` is truncated to its first
        ``num`` entries. Every other value is truncated to the rows before the
        start of trajectory ``num``, or left whole if no trajectory is dropped.

    Raises:
        KeyError: if ``data_dict`` is non-empty but has no ``"index"`` key.
    """
    if not data_dict:
        return {}
    index = data_dict["index"]
    # index[num] is the start of the first trajectory being dropped. When
    # num covers every trajectory there is no such entry, so slice to the end
    # (None) instead of indexing past the array and raising IndexError.
    end = int(index[num]) if num < len(index) else None
    samples = {}
    for k, v in data_dict.items():
        samples[k] = v[0:num] if k == "index" else v[0:end]
    return samples
from:
offlinerl/neorl