get_env_shape


def get_env_shape(task: str, *, env=None) -> "Tuple[Union[int, Tuple[int, ...]], int]":
    """Return the observation and action dimensions of a task's environment.

    Args:
        task: Task name passed to ``get_env`` to build the environment.
            Ignored when ``env`` is supplied.
        env: Optional already-constructed environment to inspect instead of
            building a new one via ``get_env(task)``. It must expose
            ``observation_space.shape`` and ``action_space``.

    Returns:
        ``(obs_dim, act_dim)``.

        ``obs_dim`` is a plain ``int`` when the observation shape is
        one-dimensional. Otherwise it is the full shape tuple, for example
        for image observations.

        ``act_dim`` is ``action_space.n`` when the action space defines ``n``
        (e.g. gym's ``Discrete``). Otherwise it is ``action_space.shape[0]``.
    """
    if env is None:
        env = get_env(task)

    obs_shape = env.observation_space.shape
    # Collapse a 1-D shape such as (17,) to 17; keep multi-dim shapes intact.
    obs_dim = obs_shape[0] if len(obs_shape) == 1 else obs_shape

    action_space = env.action_space
    if hasattr(action_space, "n"):
        # Spaces with a finite number of actions report their count via `n`.
        act_dim = action_space.n
    else:
        act_dim = action_space.shape[0]

    return obs_dim, act_dim



References:
- offlinerl / neorl
posted @ 2022-04-25 21:49  呦呦南山  阅读(25)  评论(0)  编辑  收藏  举报