# A Monte Carlo Tree Search example.
"""My Monte Carlo Tree Search Demo."""
import argparse
import math
import random
from copy import copy, deepcopy

try:
    # typing.Self is in the standard library from Python 3.11 onwards.
    from typing import Self
except ImportError:  # Python < 3.11: fall back to the backport.
    from typing_extensions import Self


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, help="Fix random seed", default=0)
    parser.add_argument("--tape_length", type=int, help="Tape length", default=50)
    parser.add_argument(
        "--sample_times_limit", type=int, help="Sample times limit", default=100
    )
    parser.add_argument(
        "--exploration_constant", type=float, help="Exploration constant", default=1.0
    )
    return parser.parse_args()


def set_seed(seed: int) -> None:
    """Seed the ``random`` module for reproducibility."""
    random.seed(seed)


class Action:
    """A single move: write a 1 at ``write_position`` on the tape.

    NOTE: identity (not value) equality is intentional — tree nodes key
    their ``children`` dicts by the exact Action objects held in their
    state's action list.
    """

    def __init__(self, write_position: int) -> None:
        self.write_position = write_position

    def __repr__(self) -> str:
        return f"Action(write_position={self.write_position})"


class State:
    """Game state: a binary tape plus a count of writes performed so far."""

    def __init__(self, tape_length: int) -> None:
        self.tape = [0] * tape_length
        self.tape_length = tape_length
        # One action per tape cell. The list is never mutated, so successor
        # states share it (see take_action) instead of deep-copying it.
        self.possible_actions = [Action(write_position=i) for i in range(tape_length)]
        self.written_times = 0

    def get_possible_actions(self) -> list:
        """Return the legal actions, or [] once the state is terminal."""
        if self.is_terminal():
            return []
        return self.possible_actions

    def take_action(self, action: "Action | None") -> Self:
        """Return the successor state after ``action`` (``self`` if None)."""
        if action is None:
            return self
        # Shallow copy plus a fresh tape instead of deepcopy: deepcopy would
        # also clone the O(n) action list on every simulated step, making a
        # single rollout O(n^2).  Only the tape needs to be duplicated.
        new_state = copy(self)
        new_state.tape = self.tape.copy()
        new_state.tape[action.write_position] = 1
        new_state.written_times = self.written_times + 1
        return new_state

    def is_terminal(self) -> bool:
        """The game ends after ``tape_length`` writes."""
        return self.written_times >= self.tape_length

    def get_reward(self) -> int:
        """Reward is the number of cells set to 1 (duplicates don't count)."""
        return sum(self.tape)

    def show_tape(self) -> None:
        """Print the tape."""
        print(self.tape)


class TreeNode:
    """A node of the MCTS search tree."""

    def __init__(self, state: State, parent: "TreeNode | None") -> None:
        self.state = state
        self.is_terminal = state.is_terminal()
        # A terminal node has no children, hence it is trivially fully expanded.
        self.is_fully_expanded = self.is_terminal
        self.parent = parent
        self.num_visits = 0
        self.total_reward = 0
        self.children = {}
class MCTS:
    """Monte Carlo Tree Search.

    Each iteration (1) selects a leaf via UCT, (2) expands one child,
    (3) runs a uniformly random rollout, and (4) backpropagates the reward.
    """

    def __init__(self, iteration_limit: int, exploration_constant: float) -> None:
        # Number of select/rollout/backpropagate iterations per search() call.
        self.search_limit = iteration_limit
        # UCT exploration weight (0 means pure exploitation).
        self.exploration_constant = exploration_constant

    def search(self, initial_state: "State") -> "Action | None":
        """Run MCTS from ``initial_state`` and return the best root action.

        Returns None when the state is already terminal.
        """
        if initial_state.is_terminal():
            return None
        root = TreeNode(initial_state, None)
        for _ in range(self.search_limit):
            node = self.select_node(root)
            reward = self.rollout(node.state)
            self.back_propagate(node, reward)
        # Final move selection is greedy: exploration weight 0.
        return self.get_best_action_child(root, 0.0)[0]

    def select_node(self, node: "TreeNode") -> "TreeNode":
        """Descend via UCT until reaching an expandable or terminal node."""
        while not node.is_terminal:
            if node.is_fully_expanded:
                _, node = self.get_best_action_child(node, self.exploration_constant)
            else:
                return self.expand(node)
        return node

    def get_best_action_child(self, node: "TreeNode", exploration_value: float) -> tuple:
        """Return an (action, child) pair maximizing the UCT value.

        Ties are broken uniformly at random.  Iterates the expanded children
        only, so this is safe even when ``node`` is not fully expanded —
        e.g. the final greedy pick when the sample budget is smaller than
        the number of actions (indexing children by possible_actions would
        raise KeyError there).
        """
        best_value = float("-inf")
        best_pairs = []
        for action, child in node.children.items():
            if child.num_visits == 0:
                # Unvisited child: pick it immediately (also avoids a
                # division by zero below).
                return action, child
            child_value = (
                child.total_reward / child.num_visits
                + exploration_value
                * math.sqrt(2 * math.log(node.num_visits) / child.num_visits)
            )
            if child_value > best_value:
                best_value = child_value
                best_pairs = [(action, child)]
            elif child_value == best_value:
                best_pairs.append((action, child))
        return random.choice(best_pairs)

    def rollout(self, state: "State") -> int:
        """Play uniformly random actions to a terminal state; return its reward."""
        while not state.is_terminal():
            action = random.choice(state.get_possible_actions())
            state = state.take_action(action)
        return state.get_reward()

    def back_propagate(self, node: "TreeNode", reward: int) -> None:
        """Add ``reward`` and one visit to every node up to the root."""
        while node is not None:
            node.num_visits += 1
            node.total_reward += reward
            node = node.parent

    def expand(self, node: "TreeNode") -> "TreeNode":
        """Create the child for the first untried action and return it."""
        actions = node.state.get_possible_actions()
        for action in actions:
            if action not in node.children:
                new_node = TreeNode(node.state.take_action(action), node)
                node.children[action] = new_node
                if len(actions) == len(node.children):
                    node.is_fully_expanded = True
                return new_node
        # select_node only calls expand() on non-fully-expanded nodes, so
        # falling through means the bookkeeping above is broken: fail loudly
        # instead of returning None (which the caller would dereference).
        raise RuntimeError("expand() called on a fully expanded node")


if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)
    game_state = State(args.tape_length)
    searcher = MCTS(
        iteration_limit=args.sample_times_limit,
        exploration_constant=args.exploration_constant,
    )
    # Re-run the search from each successive state until the tape is full.
    for _ in range(args.tape_length):
        best_action = searcher.search(initial_state=game_state)
        game_state = game_state.take_action(best_action)
    game_state.show_tape()
    print("Final reward:", game_state.get_reward())