An Example of Monte Carlo Tree Search

""" My Monte Carlo Tree Search Demo """

import argparse
import math
import random

from copy import deepcopy
from typing import Optional

from typing_extensions import Self


def parse_args() -> argparse.Namespace:
    """Parse arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, help="Fix random seed", default=0)
    parser.add_argument("--tape_length", type=int, help="Tape length", default=50)
    parser.add_argument(
        "--sample_times_limit", type=int, help="Sample times limit", default=100
    )
    parser.add_argument(
        "--exploration_constant", type=float, help="Exploration constant", default=1.0
    )
    return parser.parse_args()


def set_seed(seed: int) -> None:
    """Set seed for reproducibility."""
    random.seed(seed)


class Action:
    """Action class."""

    def __init__(self, write_position: int) -> None:
        self.write_position = write_position


class State:
    """State class."""

    def __init__(self, tape_length: int) -> None:
        self.tape = [0] * tape_length
        self.tape_length = tape_length
        self.possible_actions = []
        for i in range(self.tape_length):
            self.possible_actions.append(Action(write_position=i))
        self.written_times = 0

    def get_possible_actions(self) -> list[Action]:
        """Get possible actions (none when the state is terminal)."""
        if self.is_terminal():
            return []
        return self.possible_actions

    def take_action(self, action: Optional[Action]) -> Self:
        """Return a copy of the state with the action applied (self if action is None)."""
        if action is None:
            return self
        new_state = deepcopy(self)
        new_state.tape[action.write_position] = 1
        new_state.written_times = self.written_times + 1
        return new_state

    def is_terminal(self) -> bool:
        """Check if the state is terminal."""
        return self.written_times == self.tape_length

    def get_reward(self) -> int:
        """Get reward: the number of tape cells set to 1."""
        return sum(self.tape)

    def show_tape(self) -> None:
        """Show tape."""
        print(self.tape)


class TreeNode:
    """Tree node class."""

    def __init__(self, state: State, parent: Optional["TreeNode"]) -> None:
        self.state = state
        self.is_terminal = state.is_terminal()
        self.is_fully_expanded = self.is_terminal
        self.parent = parent
        self.num_visits = 0  # N: number of times this node has been visited
        self.total_reward = 0  # Q: sum of rollout rewards propagated through this node
        self.children = {}  # maps Action -> TreeNode


class MCTS:
    """Monte Carlo Tree Search class."""

    def __init__(self, iteration_limit: int, exploration_constant: float) -> None:
        self.search_limit = iteration_limit
        self.exploration_constant = exploration_constant

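    # Each search iteration performs the four classic MCTS phases:
    #   1. Selection: descend via UCT until an expandable or terminal node.
    #   2. Expansion: add one unexplored child to that node.
    #   3. Simulation (rollout): play randomly from the child to a terminal state.
    #   4. Backpropagation: add the rollout reward to every node on the path.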
    def search(self, initial_state: State) -> Optional[Action]:
        """Search for the best action from the initial state."""
        if initial_state.is_terminal():
            return None
        root = TreeNode(initial_state, None)
        for _ in range(self.search_limit):
            node = self.select_node(root)
            reward = self.rollout(node.state)
            self.back_propagate(node, reward)
        # With exploration constant 0, pick the action with the best average reward.
        return self.get_best_action_child(root, 0.0)[0]

    def select_node(self, node: TreeNode) -> TreeNode:
        """Walk down the tree via UCT and return the next node to simulate from."""
        while not node.is_terminal:
            if node.is_fully_expanded:
                _, node = self.get_best_action_child(node, self.exploration_constant)
            else:
                return self.expand(node)
        return node

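    # UCT score of a child:
    #     total_reward / num_visits
    #     + c * sqrt(2 * ln(parent.num_visits) / num_visits)
    # The first term favors actions with high average reward (exploitation);
    # the second favors rarely tried actions (exploration). c = 0 disables
    # exploration, which is how search() picks the final action.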
    def get_best_action_child(self, node: TreeNode, exploration_value: float) -> tuple:
        """Get the (action, child) pair with the highest UCT score."""
        best_value = float("-inf")
        best_actions_children = []
        # Iterate over expanded children only; the node may not be fully
        # expanded when this is called from search() at the end of a run.
        for action, child in node.children.items():
            if child.num_visits == 0:
                # An unvisited child has an unbounded UCT score; visit it first.
                return action, child
            child_value = (
                child.total_reward / child.num_visits
                + exploration_value
                * math.sqrt(2 * math.log(node.num_visits) / child.num_visits)
            )
            if child_value > best_value:
                best_value = child_value
                best_actions_children = [(action, child)]
            elif child_value == best_value:
                best_actions_children.append((action, child))
        # Break ties between equally good children at random.
        return random.choice(best_actions_children)

    def rollout(self, state: State) -> int:
        """Simulate a uniformly random playout to a terminal state and return its reward."""
        while not state.is_terminal():
            action = random.choice(state.get_possible_actions())
            state = state.take_action(action)
        return state.get_reward()

    def back_propagate(self, node: TreeNode, reward: int) -> None:
        """Propagate the rollout reward from the node up to the root."""
        while node is not None:
            node.num_visits += 1
            node.total_reward += reward
            node = node.parent

    def expand(self, node: TreeNode) -> TreeNode:
        """Create one unexplored child of the node and return it."""
        actions = node.state.get_possible_actions()
        for action in actions:
            if action not in node.children:
                new_node = TreeNode(node.state.take_action(action), node)
                node.children[action] = new_node
                if len(actions) == len(node.children):
                    node.is_fully_expanded = True
                # Expand exactly one child per call.
                return new_node
        raise RuntimeError("expand() called on a fully expanded node")


if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)
    game_state = State(args.tape_length)
    searcher = MCTS(
        iteration_limit=args.sample_times_limit,
        exploration_constant=args.exploration_constant,
    )
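    # Play the game: at each step, run a fresh MCTS search from the current
    # state and commit to the best action it finds.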
    for _ in range(args.tape_length):
        best_action = searcher.search(initial_state=game_state)
        game_state = game_state.take_action(best_action)
        game_state.show_tape()
    print("Final reward:", game_state.get_reward())
