Python Exercise: Loading object-dtype Data
Author: 凯鲁嘎吉 (Coral Gajic) - cnblogs http://www.cnblogs.com/kailugaji/
Given an .npy file, loading it in Python reveals dtype=object. This post shows how to access and load data stored this way, convert it to images, and save the results in PNG and GIF formats.
The data files pool.npy and TrajPool.npy used here are available at: https://files-cdn.cnblogs.com/files/kailugaji/Pool_Datasets.rar?t=1682420448&download=true
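Some background first: an object-dtype .npy file of this kind is usually produced by passing a Python dict to np.save, which wraps the dict in a 0-d array of dtype=object. Since NumPy 1.16.3, np.load refuses such pickled data unless allow_pickle=True is passed, and .item() unwraps the 0-d array back into the dict. A minimal, self-contained sketch with a made-up dict (demo.npy and its keys are placeholders, not the pool data):

import numpy as np

# Saving a plain dict wraps it in a 0-d array of dtype=object.
demo = {'rewards': np.zeros((5, 1)), 'actions': np.ones((5, 6))}
np.save('demo.npy', demo)

# Loading it back: allow_pickle=True is mandatory for object arrays.
loaded = np.load('demo.npy', allow_pickle=True)
print(loaded.dtype, loaded.shape)   # object ()
pool = loaded.item()                # .item() unwraps the 0-d array into the dict
print(pool.keys())                  # dict_keys(['rewards', 'actions'])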
1. object_load.py
# -*- coding: utf-8 -*-
# Author: 凯鲁嘎吉 Coral Gajic
# https://www.cnblogs.com/kailugaji/
# Python exercise: loading object-dtype data
# Example: a reinforcement-learning replay buffer
# Data source: cheetah-run from the DeepMind Control Suite
# From the current state, the agent takes a random action, interacts with
# the environment, and receives the next state and a reward.
# 50 interactions give a replay buffer of 50 samples: {s, a, s', r, ter},
# i.e. current state, action, next state, reward, terminal flag.
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from matplotlib import animation
# DMControlEnv("cheetah","run")

# Animate a list of HxWx3 uint8 frames and save them as a GIF.
def save_frames_as_gif(frames, path, index):
    filename = 'gym_' + index + '.gif'
    patch = plt.imshow(frames[0])
    plt.axis('off')
    def animate(i):
        patch.set_data(frames[i])
    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50, repeat=True, repeat_delay=10)
    anim.save(path + filename, writer='pillow', fps=60)
    return anim

num = 32
# allow_pickle=True is required: np.load refuses object arrays by default.
dataset = np.load(r'./pool.npy', allow_pickle=True)
print('Data type:', 'dtype =', dataset.dtype)
# dtype=object
observations = dataset.item()['observations']  # (50, 9, 64, 64)
print('Number of samples:', len(observations))  # 50
print('Keys in each sample:', dataset.item().keys())
# dict_keys(['observations', 'next_observations', 'actions', 'rewards', 'terminals'])
next_observations = dataset.item()['next_observations']  # (50, 9, 64, 64)
terminals = dataset.item()['terminals']  # (50, 1)
rewards = dataset.item()['rewards']  # (50, 1)
actions = dataset.item()['actions']  # (50, 6)
toPIL = transforms.ToPILImage()
frames = []
fig = plt.figure(figsize=(15, 6))
print('Selecting the first %d samples:' % num)
for j in range(num):
    # Keep the first 3 of the 9 channels and convert CHW -> HWC for display.
    state = observations[j, 0:3, :, :].transpose((1, 2, 0))
    frames.append(state.astype(np.uint8))
    pic = toPIL(state.astype(np.uint8))
    plt.subplot(4, num // 4, j + 1)
    plt.axis('off')
    plt.imshow(pic)
    print(j,
          '\tReward:', np.round(rewards[j], 3),
          '\tAction:', np.round(actions[j], 3),
          '\tTerminal:', terminals[j])
plt.savefig('cheetah-run.png', bbox_inches='tight', pad_inches=0.0, dpi=500)
plt.show()
save_frames_as_gif(frames, path='./', index='cheetah-run')
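A note on the slice observations[j, 0:3, :, :]: each observation has 9 channels, which matches 3 stacked RGB frames (frame stacking is common for pixel-based control in the DeepMind Control Suite), and the script above displays only the first frame's 3 channels. The sketch below, meant to be run after the script, splits out all three frames; the frame-major channel layout is an assumption about this particular buffer:

obs = observations[0]                # (9, 64, 64)
stack = obs.reshape(3, 3, 64, 64)    # (frame, channel, H, W) -- assumed layout
for k, frame in enumerate(stack):
    img = frame.transpose((1, 2, 0)).astype(np.uint8)  # CHW -> HWC, uint8 for display
    plt.subplot(1, 3, k + 1)
    plt.axis('off')
    plt.imshow(img)
plt.show()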
2. Results
D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/dict/object_load.py"
Data type: dtype = object
Number of samples: 50
Keys in each sample: dict_keys(['observations', 'next_observations', 'actions', 'rewards', 'terminals'])
Selecting the first 32 samples:
0   Reward: [0.125]   Action: [-0.807 0.717 -0.953 0.181 -0.283 0.841]   Terminal: [0.]
1   Reward: [0.099]   Action: [-0.449 -0.307 0.473 0.719 0.055 -0.44 ]   Terminal: [0.]
2   Reward: [0.075]   Action: [ 0.985 -0.704 -0.039 -0.867 0.092 -0.714]   Terminal: [0.]
3   Reward: [0.108]   Action: [ 0.128 0.358 -0.66 0.788 -0.447 0.014]   Terminal: [0.]
4   Reward: [0.105]   Action: [-0.871 0.691 0.301 0.521 -0.547 0.144]   Terminal: [0.]
5   Reward: [0.043]   Action: [-0.687 0.79 0.455 0.584 0.179 0.568]   Terminal: [0.]
6   Reward: [0.]   Action: [-0.022 0.306 0.66 0.978 -0.361 -0.869]   Terminal: [0.]
7   Reward: [0.]   Action: [ 0.503 0.017 0.505 -0.649 -0.205 -0.179]   Terminal: [0.]
8   Reward: [0.]   Action: [ 0.993 -0.424 -0.48 -0.127 0.341 0.458]   Terminal: [0.]
9   Reward: [0.]   Action: [ 0.486 0.229 -0.494 -0.417 -0.93 0.258]   Terminal: [0.]
10   Reward: [0.]   Action: [ 0.505 -0.009 -0.047 -0.004 0.64 -0.223]   Terminal: [0.]
11   Reward: [0.]   Action: [ 0.103 0.038 0.757 -0.764 -0.852 0.023]   Terminal: [0.]
12   Reward: [0.]   Action: [-0.385 -0.62 0.126 0.046 0.135 0.871]   Terminal: [0.]
13   Reward: [0.]   Action: [-0.661 -0.92 0.128 0.705 0.841 0.32 ]   Terminal: [0.]
14   Reward: [0.]   Action: [ 0.515 0.011 -0.085 -0.863 0.69 -0.899]   Terminal: [0.]
15   Reward: [0.]   Action: [-0.16 0.08 0.342 -0.675 0.873 0.13 ]   Terminal: [0.]
16   Reward: [0.]   Action: [-0.221 -0.102 0.862 -0.151 0.938 0.122]   Terminal: [0.]
17   Reward: [0.]   Action: [ 0.915 0.735 -0.297 0.357 0.613 0.363]   Terminal: [0.]
18   Reward: [0.]   Action: [ 0.752 -0.251 -0.505 -0.525 0.76 0.026]   Terminal: [0.]
19   Reward: [0.]   Action: [-0.907 0.056 0.108 -0.921 -0.164 -0.508]   Terminal: [0.]
20   Reward: [0.]   Action: [-0.522 -0.065 -0.66 -0.229 0.88 0.583]   Terminal: [0.]
21   Reward: [0.]   Action: [-0.011 -0.137 0.209 0.014 -0.079 0.236]   Terminal: [0.]
22   Reward: [0.]   Action: [-0.663 0.654 -0.068 -0.728 0.537 -0.359]   Terminal: [0.]
23   Reward: [0.]   Action: [-0.602 -0.122 -0.313 -0.798 0.354 -0.558]   Terminal: [0.]
24   Reward: [0.]   Action: [-0.667 0.071 0.508 -0.219 -0.007 0.041]   Terminal: [0.]
25   Reward: [0.]   Action: [ 0.993 0.028 -0.229 0.809 0.502 0.281]   Terminal: [0.]
26   Reward: [0.]   Action: [ 0.335 0.411 -0.902 -0.487 -0.564 0.109]   Terminal: [0.]
27   Reward: [0.]   Action: [-0.509 -0.607 0.294 -0.391 0.997 0.134]   Terminal: [0.]
28   Reward: [0.]   Action: [ 0.312 0.554 0.741 -0.098 -0.257 -0.768]   Terminal: [0.]
29   Reward: [0.]   Action: [-0.855 -0.576 -0.122 -0.714 -0.436 -0.335]   Terminal: [0.]
30   Reward: [0.]   Action: [ 0.797 0.024 -0.432 -0.378 -0.555 0.935]   Terminal: [0.]
31   Reward: [0.]   Action: [ 0.768 0.445 0.59 -0.977 0.51 0.796]   Terminal: [0.]
Process finished with exit code 0
As a supplement, here is another .npy file:
Code:
# -*- coding: utf-8 -*-
# Author: 凯鲁嘎吉 Coral Gajic
# https://www.cnblogs.com/kailugaji/
# Python exercise: loading object-dtype data
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from matplotlib import animation
# DMControlEnv("cheetah","run")

# Animate a list of HxWx3 uint8 frames and save them as a GIF.
def save_frames_as_gif(frames, path, index):
    filename = 'gym_' + index + '_traj.gif'
    patch = plt.imshow(frames[0])
    plt.axis('off')
    def animate(i):
        patch.set_data(frames[i])
    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50, repeat=True, repeat_delay=10)
    anim.save(path + filename, writer='pillow', fps=60)
    return anim

num = 32
# allow_pickle=True is required: np.load refuses object arrays by default.
dataset = np.load(r'./TrajPool.npy', allow_pickle=True)
print('Data type:', 'dtype =', dataset.dtype)
# dtype=object
print('Keys in each sample:', dataset.item().keys())
# dict_keys(['frames', 'actions', 'rewards', 'terminals'])
observations = dataset.item()['frames']  # (300, 3, 64, 64)
print('Number of samples:', len(observations))  # 300
terminals = dataset.item()['terminals']  # (300, 1)
rewards = dataset.item()['rewards']  # (300, 1)
actions = dataset.item()['actions']  # (300, 6)
toPIL = transforms.ToPILImage()
frames = []
fig = plt.figure(figsize=(15, 6))
print('Selecting the first %d samples:' % num)
for j in range(num):
    # Frames are already 3-channel here; convert CHW -> HWC for display.
    state = observations[j, :, :, :].transpose((1, 2, 0))
    frames.append(state.astype(np.uint8))
    pic = toPIL(state.astype(np.uint8))
    plt.subplot(4, num // 4, j + 1)
    plt.axis('off')
    plt.imshow(pic)
    print(j,
          '\tReward:', np.round(rewards[j], 3),
          '\tAction:', np.round(actions[j], 3),
          '\tTerminal:', terminals[j])
plt.savefig('cheetah-run-traj.png', bbox_inches='tight', pad_inches=0.0, dpi=500)
plt.show()
save_frames_as_gif(frames, path='./', index='cheetah-run')
Results:
D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/load_npy/object_load_traj.py"
Data type: dtype = object
Keys in each sample: dict_keys(['frames', 'actions', 'rewards', 'terminals'])
Number of samples: 300
Selecting the first 32 samples:
0   Reward: [0.]   Action: [-0.491 -0.432 0.258 -0.417 0.777 -0.192]   Terminal: [0.]
1   Reward: [0.]   Action: [-0.395 0.082 -0.993 -0.983 0.353 0.559]   Terminal: [0.]
2   Reward: [0.]   Action: [ 0.655 -0.55 -0.464 -0.59 -0.836 -0.251]   Terminal: [0.]
3   Reward: [0.]   Action: [-0.312 0.005 -0.037 -0.916 0.555 0.06 ]   Terminal: [0.]
4   Reward: [0.]   Action: [ 0.495 -0.105 -0.743 0.734 -0.603 -0.588]   Terminal: [0.]
5   Reward: [0.]   Action: [-0.064 0.105 0.172 0.498 -0.984 -0.174]   Terminal: [0.]
6   Reward: [0.]   Action: [ 0.652 0.64 -0.743 0.108 0.663 -0.094]   Terminal: [0.]
7   Reward: [0.]   Action: [-0.12 0.795 0.422 0.401 0.829 0.094]   Terminal: [0.]
8   Reward: [0.]   Action: [0. 0. 0. 0. 0. 0.]   Terminal: [0.]
9   Reward: [0.]   Action: [0. 0. 0. 0. 0. 0.]   Terminal: [0.]
10   Reward: [0.]   Action: [0. 0. 0. 0. 0. 0.]   Terminal: [0.]
11   Reward: [0.]   Action: [-0.246 -0.135 0.181 -0.473 0.584 -0.166]   Terminal: [0.]
12   Reward: [0.]   Action: [ 0.052 -0.931 0.894 0.028 -0.669 0.218]   Terminal: [0.]
13   Reward: [0.]   Action: [ 0.974 0.133 -0.692 -0.208 0.065 -0.746]   Terminal: [0.]
14   Reward: [0.]   Action: [ 0.834 -0.767 0.423 -0.127 0.133 -0.662]   Terminal: [0.]
15   Reward: [0.]   Action: [-0.893 0.482 0.973 0.219 -0.745 0.335]   Terminal: [0.]
16   Reward: [0.]   Action: [ 0.777 -0.479 -0.601 0.209 0.435 0.342]   Terminal: [0.]
17   Reward: [0.]   Action: [ 0.936 -0.639 0.932 -0.909 -0.519 0.674]   Terminal: [0.]
18   Reward: [0.]   Action: [-0.913 0.719 -0.84 -0.065 0.845 0.524]   Terminal: [0.]
19   Reward: [0.]   Action: [-0.156 0.533 0.796 -0.911 -0.99 0.207]   Terminal: [0.]
20   Reward: [0.]   Action: [ 0.343 0.445 0.183 0.306 0.884 -0.94 ]   Terminal: [0.]
21   Reward: [0.]   Action: [-0.157 -0.648 -0.125 -0.268 0.123 0.771]   Terminal: [0.]
22   Reward: [0.]   Action: [ 0.44 0.574 -0.65 0.183 -0.043 -0.823]   Terminal: [0.]
23   Reward: [0.]   Action: [ 0.855 -0.087 -0.347 -0.616 -0.843 0.184]   Terminal: [0.]
24   Reward: [0.]   Action: [-0.269 -0.191 0.047 -0.363 -0.193 -0.786]   Terminal: [0.]
25   Reward: [0.]   Action: [-0.472 0.426 0.367 0.316 0.947 0.675]   Terminal: [0.]
26   Reward: [0.]   Action: [-0.058 0.245 0.744 -0.542 0.754 0.616]   Terminal: [0.]
27   Reward: [0.]   Action: [-0.429 -0.531 0.737 -0.507 -0.712 -0.692]   Terminal: [0.]
28   Reward: [0.]   Action: [-0.251 -0.844 0.539 0.258 -0.111 0.73 ]   Terminal: [0.]
29   Reward: [0.03]   Action: [0.54 0.822 0.214 0.598 0.653 0.473]   Terminal: [0.]
30   Reward: [0.085]   Action: [-0.365 -0.202 -0.25 0.548 0.461 -0.922]   Terminal: [0.]
31   Reward: [0.121]   Action: [ 0.301 0.356 -0.387 0.764 0.333 0.725]   Terminal: [0.]
Process finished with exit code 0
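Since the two pools use different key names ('observations' in pool.npy vs. 'frames' in TrajPool.npy), it pays to inspect a pool's keys and array shapes before writing any plotting code. A small generic sketch, assuming every value in the dict is array-like:

import numpy as np

pool = np.load('./TrajPool.npy', allow_pickle=True).item()
for key, value in pool.items():
    value = np.asarray(value)
    print(key, value.shape, value.dtype)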