Creating and customizing a gym environment

Environment: half_cheetah.py

The class below builds a multi-objective HalfCheetah on top of gymnasium's MujocoEnv: step() returns a 2-dimensional reward vector (running speed and energy / control cost) instead of a scalar.

from os import path

import numpy as np

from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box

DEFAULT_CAMERA_CONFIG = {
    "distance": 4.0,
}


class MOHalfCheetahEnv(MujocoEnv, utils.EzPickle):
    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
        "render_fps": 20,
    }

    def __init__(
            self,
            **kwargs,
    ):
        utils.EzPickle.__init__(
            self,
            **kwargs,
        )

        # observation space: 17 dimensions (qpos without the x position, plus qvel)
        observation_space = Box(
            low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
        )

        # initialize the underlying MujocoEnv
        MujocoEnv.__init__(
            self,
            "half_cheetah.xml",  # use the XML model shipped with the library
            5,  # frame_skip
            observation_space=observation_space,
            default_camera_config=DEFAULT_CAMERA_CONFIG,
            **kwargs,
        )

        # multi-objective attributes: 2-dimensional reward vector (run, energy)
        self.reward_space = Box(low=-np.inf, high=np.inf, shape=(2,))
        self.reward_dim = 2

    def step(self, action):
        # PGMORL / PDMORL clip the action to the valid range directly in step()
        action = np.clip(action, -1.0, 1.0)

        # forward velocity from the change in x position over one control step
        x_position_before = self.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        x_position_after = self.data.qpos[0]
        x_velocity = (x_position_after - x_position_before) / self.dt

        # observation
        observation = self._get_obs()

        # reward: objective 1 is forward speed (capped at 4.0), objective 2 is low control cost
        alive_bonus = 1
        reward_run = min(4.0, x_velocity) + alive_bonus
        reward_energy = 4.0 - 1.0 * np.square(action).sum() + alive_bonus
        vec_reward = np.array([reward_run, reward_energy], dtype=np.float32)

        # terminated / truncated
        ang = self.data.qpos[2]
        # terminated = not (abs(ang) < np.deg2rad(50))  # PGMORL / PDMORL terminate on a large pitch angle
        terminated = False  # no termination condition here
        truncated = False  # truncation is left to the TimeLimit wrapper (max_episode_steps)

        # info
        info = {}

        # render
        if self.render_mode == "human":
            self.render()

        return observation, vec_reward, terminated, truncated, info

    def _get_obs(self):
        position = self.data.qpos.flat.copy()
        velocity = self.data.qvel.flat.copy()

        position = position[1:]  # drop the x position so the observation has 17 dimensions

        observation = np.concatenate((position, velocity)).ravel()
        return observation

    def reset_model(self):
        # re-sample the initial state with small noise around the default pose
        qpos = self.init_qpos + self.np_random.uniform(
            low=-0.1, high=0.1, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
        self.set_state(qpos, qvel)
        return self._get_obs()
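
A quick sanity check of the class above (a minimal sketch; it just instantiates the environment directly and inspects its spaces):

from half_cheetah import MOHalfCheetahEnv

env = MOHalfCheetahEnv()
print(env.observation_space.shape)             # (17,)
print(env.action_space.shape)                  # (6,) actuators in half_cheetah.xml
print(env.reward_space.shape, env.reward_dim)  # (2,) 2

obs, info = env.reset(seed=0)
obs, vec_r, terminated, truncated, info = env.step(env.action_space.sample())
print(vec_r)  # 2-dimensional reward vector: [reward_run, reward_energy]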

Registration, with the environment checker disabled

Since step() returns a 2-dimensional reward vector rather than a scalar, gymnasium's passive environment checker would warn about the reward type, so the environment is created with disable_env_checker=True.

from gymnasium.envs.registration import register
import mo_gymnasium as mo_gym
from half_cheetah import MOHalfCheetahEnv

register(
    id="wx-half-v1",
    entry_point=MOHalfCheetahEnv,
    max_episode_steps=500,
)
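
As an alternative (a sketch, assuming the class above lives in an importable module named half_cheetah), entry_point can also be given as a "module:ClassName" string, so the class is only imported when the environment is actually made:

register(
    id="wx-half-v1",
    entry_point="half_cheetah:MOHalfCheetahEnv",
    max_episode_steps=500,
)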

if __name__ == '__main__':
    import gymnasium as gym

    # env = MOHalfCheetahEnv(render_mode="human")
    # env = MOHalfCheetahEnv()
    # env = mo_gym.make('mo-halfcheetah-v4')  # never terminates; truncated after 1000 steps
    # env = gym.make("HalfCheetah-v4")  # never terminates; truncated after 1000 steps
    env = gym.make("wx-half-v1", disable_env_checker=True)

    done = False
    obv, info = env.reset(seed=5)
    env.action_space.seed(5)
    env.observation_space.seed(5)

    print(type(env))

    steps = 0
    while not done:
        action = env.action_space.sample()
        obv, r, d1, d2, _ = env.step(action)
        # print(r)
        done = d1 or d2
        steps += 1
        print(steps)

    print(steps)
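
Because step() returns a reward vector, a standard single-objective algorithm needs a scalarization. A minimal weighted-sum sketch (the preference weights are arbitrary, purely for illustration):

import numpy as np

from half_cheetah import MOHalfCheetahEnv

env = MOHalfCheetahEnv()
weights = np.array([0.7, 0.3])  # hypothetical preference over (run, energy)

obs, info = env.reset(seed=5)
episode_return = np.zeros(env.reward_dim)
for _ in range(500):
    obs, vec_r, terminated, truncated, _ = env.step(env.action_space.sample())
    episode_return += vec_r                   # per-objective return
    scalar_r = float(np.dot(weights, vec_r))  # scalar reward for a single-objective learner
    if terminated or truncated:
        break
print(episode_return)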
