点此进入CSDN

点此添加QQ好友（图片加载失败时会显示此文字）




用 DQN 玩 CartPole 游戏

import math
import random
import sys
from collections import deque

import gym
import pygame
import torch
import torch.nn as nn
import torch.optim as optim

# 定义DQN模型
class DQN(nn.Module):
    """Q-network mapping an observation vector to one Q-value per action.

    A one-hidden-layer MLP: Linear -> ReLU -> Linear. The defaults
    (4 observation features, 2 actions, 128 hidden units) reproduce the
    original hard-coded CartPole sizes, so ``DQN()`` behaves exactly as before.

    Args:
        n_observations: Number of input features per observation.
        n_actions: Number of discrete actions (width of the output layer).
        hidden_size: Width of the hidden layer.
    """

    def __init__(self, n_observations=4, n_actions=2, hidden_size=128):
        super().__init__()
        # The attribute is still called `network` so saved state_dict keys
        # ("network.0.weight", "network.2.bias", ...) remain loadable.
        self.network = nn.Sequential(
            nn.Linear(n_observations, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x):
        """Return Q-values; the last dim of ``x`` must equal ``n_observations``.

        The output keeps the leading dims of ``x`` and has ``n_actions`` as
        its last dim.
        """
        return self.network(x)

# 经验回放
class ReplayBuffer:
    """Bounded FIFO store of transitions used for experience replay.

    When ``capacity`` transitions are already held, pushing a new one
    evicts the oldest.
    """

    def __init__(self, capacity):
        # A deque with maxlen drops items from the left once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Record one ``(state, action, reward, next_state, done)`` transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)``;
            each element is a tuple with one entry per sampled transition.

        ``random.sample`` raises ``ValueError`` if ``batch_size`` exceeds
        the number of stored transitions.
        """
        picked = random.sample(self.buffer, batch_size)
        s_col, a_col, r_col, ns_col, d_col = zip(*picked)
        return s_col, a_col, r_col, ns_col, d_col

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.buffer)

# 训练函数
def optimize_model(gamma=0.99):
    """Run one gradient step of DQN on a random minibatch from replay memory.

    Reads the module-level ``memory``, ``model``, ``optimizer``, ``criterion``
    and ``BATCH_SIZE``. Returns immediately (no update) while the buffer holds
    fewer than ``BATCH_SIZE`` transitions.

    The TD target is ``r + gamma * max_a' Q(s', a') * (1 - done)``, evaluated
    with the same network being trained (there is no separate target network).

    Args:
        gamma: Discount factor; defaults to the previously hard-coded 0.99.
    """
    if len(memory) < BATCH_SIZE:
        return
    states, actions, rewards, next_states, dones = memory.sample(BATCH_SIZE)

    # Stack per-transition observations: calling torch.tensor() directly on a
    # tuple of numpy arrays is extremely slow and triggers a PyTorch warning.
    states = torch.stack([torch.as_tensor(s, dtype=torch.float) for s in states])
    next_states = torch.stack(
        [torch.as_tensor(s, dtype=torch.float) for s in next_states]
    )
    actions = torch.tensor(actions, dtype=torch.long)
    rewards = torch.tensor(rewards, dtype=torch.float)
    dones = torch.tensor(dones, dtype=torch.float)

    # Q(s, a) for the action that was actually taken in each transition.
    current_q_values = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # The bootstrap target must not receive gradients; no_grad also avoids
    # building an autograd graph for this forward pass.
    with torch.no_grad():
        next_q_values = model(next_states).max(1)[0]
    # (1 - dones) removes the bootstrap term for terminal transitions.
    expected_q_values = rewards + gamma * next_q_values * (1 - dones)

    loss = criterion(current_q_values, expected_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 设置环境和模型
env = gym.make('CartPole-v1')
model = DQN()
memory = ReplayBuffer(10000)
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()
BATCH_SIZE = 128
EPSILON = 0.2
# Multiplicative decay applied to EPSILON after every episode.
EPSILON_DECAY = 0.995


def _handle_quit_events():
    """Shut pygame down and exit the process if the window was closed."""
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            pygame.quit()
            sys.exit()


def _draw_cartpole(screen, state):
    """Draw one frame: the cart as a blue box and the pole as a red line.

    ``state[0]`` sets the cart's horizontal position and ``state[2]`` the pole
    angle in radians; the pixel scale and offsets are those of the original.
    """
    screen.fill((255, 255, 255))
    cart_x = int(state[0] * 100 + 300)
    pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
    angle = float(state[2])
    pole_base = (cart_x + 25, 300)
    pole_tip = (cart_x + 25 - int(50 * math.sin(angle)),
                300 - int(50 * math.cos(angle)))
    pygame.draw.line(screen, (255, 0, 0), pole_base, pole_tip, 5)
    pygame.display.flip()


def main(num_episodes=500):
    """Train the DQN on CartPole with epsilon-greedy exploration, rendering live.

    Args:
        num_episodes: Number of training episodes to run.
    """
    global EPSILON
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()
    try:
        for episode in range(num_episodes):
            # Gym >= 0.26 API: reset() returns (observation, info).
            state, _ = env.reset()
            total_reward = 0
            done = False
            while not done:
                if random.random() < EPSILON:
                    action = env.action_space.sample()
                else:
                    # Greedy action; no autograd graph is needed here.
                    with torch.no_grad():
                        state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0)
                        action = model(state_tensor).argmax(dim=1).item()

                next_state, reward, terminated, truncated, _ = env.step(action)
                # End the episode on termination OR time-limit truncation;
                # ignoring `truncated` lets a good agent run past the
                # CartPole-v1 step limit indefinitely.
                done = terminated or truncated
                # Store `terminated`, not `done`: a truncated transition is not
                # a true terminal state, so its TD target should still bootstrap.
                memory.push(state, action, reward, next_state, terminated)
                state = next_state
                total_reward += reward
                optimize_model()

                # Pygame visualization
                _handle_quit_events()
                _draw_cartpole(screen, state)
                clock.tick(60)

            EPSILON *= EPSILON_DECAY  # decay the exploration rate
            print(f'Episode {episode}: Total Reward = {total_reward}')
    finally:
        env.close()
        pygame.quit()


if __name__ == '__main__':
    main()

 

posted @ 2024-05-13 13:44  高颜值的殺生丸  阅读(14)  评论(0)  编辑  收藏  举报

作者信息

昵称:

刘新宇

园龄:4年6个月


粉丝:1209


QQ:522414928