Actor-Critic plays the CartPole game
The full script below defines a small Actor network and a Critic network, trains them with a one-step TD (actor-critic) update after every environment step, and renders the cart and pole with pygame.

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import pygame
import sys

# Define the Actor network: maps a 4-dim state to action probabilities
class Actor(nn.Module):
    def __init__(self):
        super(Actor, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),
            nn.ReLU(),
            nn.Linear(10, 2),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.fc(x)

# Define the Critic network: maps a 4-dim state to a scalar state value
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),
            nn.ReLU(),
            nn.Linear(10, 1)
        )

    def forward(self, x):
        return self.fc(x)

# One training step: TD(0) target for the critic, TD error as advantage for the actor
def train(actor, critic, actor_optimizer, critic_optimizer,
          state, action, reward, next_state, done):
    state = torch.tensor(state, dtype=torch.float)
    next_state = torch.tensor(next_state, dtype=torch.float)
    action = torch.tensor(action, dtype=torch.long)
    reward = torch.tensor(reward, dtype=torch.float)

    # Bootstrap from the next state unless the episode has ended
    if done:
        next_value = 0
    else:
        next_value = critic(next_state).detach()

    # Critic loss: squared TD error
    value = critic(state)
    expected_value = reward + 0.99 * next_value
    critic_loss = (value - expected_value).pow(2).mean()

    # Actor loss: policy gradient weighted by the TD error (advantage)
    probs = actor(state)
    dist = torch.distributions.Categorical(probs)
    log_prob = dist.log_prob(action)
    advantage = (expected_value - value).detach()  # TD error as advantage
    actor_loss = -log_prob * advantage

    # Update both networks
    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

# Set up the environment and models
env = gym.make('CartPole-v1')
actor = Actor()
critic = Critic()
actor_optimizer = optim.Adam(actor.parameters(), lr=0.001)
critic_optimizer = optim.Adam(critic.parameters(), lr=0.01)

pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock()

# Training loop
for episode in range(10000):
    state, _ = env.reset()
    done = False
    step = 0
    while not done:
        step += 1
        state_tensor = torch.tensor(state, dtype=torch.float)
        probs = actor(state_tensor)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample().item()
        # New Gym/Gymnasium API returns a 5-tuple; treat truncation as episode end too
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        train(actor, critic, actor_optimizer, critic_optimizer,
              state, action, reward, next_state, done)
        state = next_state

        # Pygame visualization
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        # Draw the cart and the pole from the current state
        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        pygame.draw.line(
            screen, (255, 0, 0),
            (cart_x + 25, 300),
            (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))),
             300 - int(50 * torch.cos(torch.tensor(state[2])))),
            5)
        pygame.display.flip()
        clock.tick(200)

    print(f"Episode {episode}: lasted {step} steps before failing")
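After training, a quick sanity check is to roll out the learned policy greedily, taking the argmax action instead of sampling. This is a minimal sketch of my own, not part of the script above; it reuses the `env` and `actor` objects defined there and assumes the same Gymnasium-style 5-tuple `step` API, and the `evaluate` helper name is made up for illustration.

# Minimal evaluation sketch (assumption): greedy rollout of the trained actor
def evaluate(env, actor, episodes=5):
    for ep in range(episodes):
        state, _ = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            with torch.no_grad():
                probs = actor(torch.tensor(state, dtype=torch.float))
            action = torch.argmax(probs).item()  # greedy action instead of sampling
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward
        print(f"Eval episode {ep}: return {total_reward}")

evaluate(env, actor)

A well-trained actor should reach returns near the 500-step cap of CartPole-v1; if returns stay low, the learning rates or the number of training episodes are the first things to revisit.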