get_env_shape
def get_env_shape(task : str) -> Tuple[int, int]:
    """Look up the observation and action dimensions of an environment.

    The environment is built with ``get_env(task)`` and its
    ``observation_space`` / ``action_space`` attributes are inspected.

    Args:
        task: Identifier forwarded unchanged to ``get_env``.

    Returns:
        A pair ``(obs_dim, act_dim)``:

        * ``obs_dim`` is the single entry of ``observation_space.shape``
          when that shape has exactly one element; for any other length
          the shape object itself is returned untouched.
        * ``act_dim`` is ``action_space.n`` when the action space has an
          ``n`` attribute, otherwise the first entry of
          ``action_space.shape``.
    """
    environment = get_env(task)
    obs_shape = environment.observation_space.shape
    act_space = environment.action_space

    # Collapse a one-element shape to a bare number; leave others as-is.
    obs_dim = obs_shape[0] if len(obs_shape) == 1 else obs_shape

    # Spaces exposing `n` (presumably discrete ones - confirm against the
    # environment library) report their size there; otherwise fall back to
    # the leading entry of the shape.
    act_dim = act_space.n if hasattr(act_space, 'n') else act_space.shape[0]

    return obs_dim, act_dim
references:
offlinerl/neorl