{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [ "7rz9TTaUt547", "iPa7Uf7ltt-r", "eBkUA_MdFwpr" ], "gpuType": "L4", "machine_shape": "hm", "authorship_tag": "ABX9TyNVvTasDtJNNgNqrI7lCCAW", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "source": [ "# Building a Chess Engine based on CNN and one based on Vision Transformer" ], "metadata": { "id": "JtjQT-pmt-Ms" } }, { "cell_type": "markdown", "source": [ "## Installs and Imports" ], "metadata": { "id": "7rz9TTaUt547" } }, { "cell_type": "code", "source": [ "!pip install python-chess" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QDffgdH7bar-", "outputId": "73e2f865-7b40-4a05-e8a6-ecc8f7dfe35c" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting python-chess\n", " Downloading python_chess-1.999-py3-none-any.whl.metadata (776 bytes)\n", "Collecting chess<2,>=1 (from python-chess)\n", " Downloading chess-1.11.2.tar.gz (6.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.1/6.1 MB\u001b[0m \u001b[31m80.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "Downloading python_chess-1.999-py3-none-any.whl (1.4 kB)\n", "Building wheels for collected packages: chess\n", " Building wheel for chess (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=05ce0aec95b34740e55ad4b33d066b405abc4b38cf207f512af7ee2e090a2722\n", " Stored in directory: /root/.cache/pip/wheels/83/1f/4e/8f4300f7dd554eb8de70ddfed96e94d3d030ace10c5b53d447\n", "Successfully built chess\n", "Installing collected packages: chess, python-chess\n", "Successfully installed chess-1.11.2 python-chess-1.999\n" ] } ] }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import chess\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")" ], "metadata": { "id": "hQXTOxy4pyN6", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "214a48d9-c691-4229-a672-52deef85f4e2" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Using device: cuda\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Build the Neural Network Class" ], "metadata": { "id": "iPa7Uf7ltt-r" } }, { "cell_type": "code", "source": [ "class ChessNet(nn.Module):\n", "\n", " def __init__(self):\n", " super(ChessNet, self).__init__()\n", " # 12 planes and 8x8 board\n", " self.conv1 = nn.Conv2d(12, 64, 1)\n", " self.conv2 = nn.Conv2d(64, 128, 1)\n", " # After flattening:\n", " self.lin1 = nn.Linear(128 * 8 * 8, 512)\n", " self.lin2 = nn.Linear(512, 256)\n", " self.lin3 = nn.Linear(256, 1)\n", "\n", " def forward(self, x):\n", " x = F.relu(self.conv1(x))\n", " x = F.relu(self.conv2(x))\n", " # Flatten the output of the convolutional layers before passing to linear layers\n", " x = x.view(-1, 128 * 8 * 8)\n", " x = F.relu(self.lin1(x))\n", " x = F.relu(self.lin2(x))\n", " x = self.lin3(x)\n", " return x\n", "\n", "\n", 
"test_model = ChessNet()\n", "\n", "# One test data point with shape: 12x8x8\n", "test_data = torch.randn(1, 12, 8, 8)\n", "\n", "output = test_model(test_data)\n", "print(output)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aMJG_2TUqPvW", "outputId": "c85a69fc-0ea8-4907-cb36-f6553d8cc3dd" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([[-0.0571]], grad_fn=)\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Create a function that turns a Board into a Tensor" ], "metadata": { "id": "eBkUA_MdFwpr" } }, { "cell_type": "code", "source": [ "def board_to_tensor(board):\n", " tensor = torch.zeros((12, 8, 8))\n", " pieces = board.piece_map()\n", " for square, piece in pieces.items():\n", " if piece.color == chess.WHITE:\n", " tensor[0 + piece.piece_type][square // 8][square % 8] = 1\n", " else:\n", " tensor[5 + piece.piece_type][square // 8][square % 8] = 1\n", " return tensor\n", "\n", "# test board to tensor function\n", "test_board = chess.Board()\n", "test_tensor = board_to_tensor(test_board)\n", "print(test_tensor)\n", "\n" ], "metadata": { "id": "2OGSEelMFwVR", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "d9dfe85d-6a4b-4946-82ac-ed21fdba9d70" }, "execution_count": 4, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", "\n", " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", "\n", " [[0., 1., 0., 0., 0., 0., 1., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", "\n", " [[0., 0., 1., 0., 0., 1., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", "\n", " [[1., 0., 0., 0., 0., 0., 0., 1.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", "\n", " [[0., 0., 0., 1., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", "\n", " [[0., 0., 0., 0., 1., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 
0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", "\n", " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 1., 0., 0., 0., 0., 1., 0.]],\n", "\n", " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 1., 0., 0., 1., 0., 0.]],\n", "\n", " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 1.]],\n", "\n", " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 1., 0., 0., 0., 0.]],\n", "\n", " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 1., 0., 0., 0.]]])\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "8a2d513f" }, "source": [ "## Define RL Training Parameters\n", "\n", "Define key parameters for reinforcement learning, such as learning rate, discount factor (gamma), exploration rate (epsilon), number of training episodes, and batch size for experience replay.\n" ] }, { "cell_type": "code", "metadata": { "id": "26e027a2" }, "source": [ "LEARNING_RATE = 5e-4\n", "GAMMA = 0.99\n", "EPSILON_START = 1.0\n", "EPSILON_END = 0.01\n", "EPSILON_DECAY = 0.995\n", "NUM_EPISODES = 3000\n", "BATCH_SIZE = 64\n", "BUFFER_SIZE = 10000" ], "execution_count": 5, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "e0863229" }, "source": [ "## Initialize Model and Optimizer\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "0aa0e1b2" }, "source": [ "model = ChessNet().to(device)\n", "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)" ], "execution_count": 6, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "aa6a5836" }, "source": [ "## Implement Experience Replay Buffer\n", "\n", "Create a simple experience replay buffer to store game transitions (state, action, reward, next_state, done) for training stability.\n" ] }, { "cell_type": "code", "metadata": { "id": "f6e19433" }, "source": [ "import random\n", "from collections import deque\n", "\n", "class ReplayBuffer:\n", " def __init__(self, capacity):\n", " self.buffer = deque(maxlen=capacity)\n", "\n", " def add(self, state, reward, next_state, done):\n", " self.buffer.append((state, reward, next_state, done))\n", "\n", " def sample(self, batch_size):\n", " if len(self.buffer) < batch_size:\n", " return [] # Not enough samples to form a batch\n", " experiences = random.sample(self.buffer, batch_size)\n", "\n", " # Unpack experience\n", " 
states, rewards, next_states, dones = zip(*experiences)\n", "\n", " return states, rewards, next_states, dones\n", "\n", " def __len__(self):\n", " return len(self.buffer)" ], "execution_count": 7, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Create a function where the model plays against itself" ], "metadata": { "id": "QQUs_Gj8dxHW" } }, { "cell_type": "markdown", "source": [ "In this function the model will evaluate all the next moves from white's perspective and play the most advantageous move.\n", "\n", "We have 3 different outcomes and rewards:\n", "- White wins: +1\n", "- Black wins: -1\n", "- Draw: 0\n", "\n", "We will add one policy of playing maximum 100 moves otherwise it is a draw --> Reward = 0" ], "metadata": { "id": "hG-M_1leeRBu" } }, { "cell_type": "code", "metadata": { "id": "075de970" }, "source": [ "import random\n", "\n", "def play_against_itself(model, epsilon, initial_board=None, max_moves=100, device='cpu'):\n", " if initial_board is None:\n", " board = chess.Board()\n", " else:\n", " board = initial_board.copy()\n", "\n", " game_transitions = []\n", "\n", " final_game_result_value = 0\n", "\n", " for move_count in range(max_moves):\n", " if board.is_game_over():\n", " break\n", "\n", " current_state_tensor = board_to_tensor(board).unsqueeze(0).to(device)\n", " legal_moves = list(board.legal_moves)\n", "\n", " if not legal_moves:\n", " # Stalemate or checkmate\n", " break\n", "\n", " best_move = None\n", " # exploration\n", " if random.random() < epsilon:\n", " random_move = random.choice(legal_moves)\n", " best_move = random_move\n", " else: # Exploitation\n", " if board.turn == chess.WHITE:\n", " best_evaluation = -float('inf') # White wants to maximize the score\n", " else:\n", " best_evaluation = float('inf') # Black wants to minimize the score\n", "\n", " # Evaluate all legal moves\n", " for move in legal_moves:\n", " temp_board = board.copy()\n", " temp_board.push(move)\n", "\n", " input_tensor = board_to_tensor(temp_board).unsqueeze(0).to(device)\n", " with torch.no_grad():\n", " evaluation = model(input_tensor).item()\n", "\n", " if board.turn == chess.WHITE:\n", " if evaluation > best_evaluation:\n", " best_evaluation = evaluation\n", " best_move = move\n", " else:\n", " if evaluation < best_evaluation:\n", " best_evaluation = evaluation\n", " best_move = move\n", "\n", " if best_move is not None:\n", " board.push(best_move)\n", " else:\n", " break\n", "\n", " next_state_tensor = board_to_tensor(board).unsqueeze(0).to(device)\n", " is_done_after_move = board.is_game_over()\n", "\n", " game_transitions.append((current_state_tensor, next_state_tensor, is_done_after_move))\n", "\n", " if is_done_after_move:\n", " break\n", "\n", " if board.is_game_over():\n", " result = board.result()\n", " if result == \"1-0\":\n", " final_game_result_value = 1 # White wins\n", " elif result == \"0-1\":\n", " final_game_result_value = -1 # Black wins\n", " else:\n", " final_game_result_value = 0 # Draw\n", "\n", " return game_transitions, final_game_result_value" ], "execution_count": 8, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "271301cc" }, "source": [ "## Implement RL Training Loop\n", "\n", "Create a training loop that plays multiple games, stores experiences in the replay buffer, samples from the buffer, calculates the Q-value loss, and updates the model's weights using backpropagation.\n" ] }, { "cell_type": "code", "metadata": { "id": "bb5cfc23", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": 
"08251dff-d41d-4050-b226-a530fe02fecd" }, "source": [ "replay_buffer = ReplayBuffer(BUFFER_SIZE)\n", "epsilon = EPSILON_START\n", "episode_rewards = []\n", "\n", "# Mean Squared Error for Q-learning\n", "criterion = nn.MSELoss()\n", "\n", "print(\"Starting RL training loop...\")\n", "\n", "for episode in range(NUM_EPISODES):\n", " game_transitions, final_game_result = play_against_itself(model, epsilon, max_moves=100, device=device)\n", "\n", " # Process game_transitions to add to replay buffer\n", " for i, (current_state_tensor, next_state_tensor, is_done_after_move) in enumerate(game_transitions):\n", " reward = 0\n", " if is_done_after_move:\n", " reward = final_game_result\n", "\n", " state_to_buffer = current_state_tensor.squeeze(0).cpu()\n", " next_state_to_buffer = next_state_tensor.squeeze(0).cpu()\n", "\n", " replay_buffer.add(state_to_buffer, reward, next_state_to_buffer, is_done_after_move)\n", "\n", " episode_rewards.append(final_game_result)\n", "\n", " # Train the model if enough experiences are in the buffer\n", " if len(replay_buffer) >= BATCH_SIZE:\n", " states, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)\n", "\n", " states = torch.stack(states).float().to(device)\n", " rewards = torch.tensor(rewards).float().to(device)\n", " next_states = torch.stack(next_states).float().to(device)\n", " dones = torch.tensor(dones).bool().to(device)\n", "\n", " # Get predicted Q-values for current states\n", " current_q_values = model(states).squeeze(1)\n", "\n", " # Calculate target Q-values\n", " with torch.no_grad():\n", " next_q_values = model(next_states).squeeze(1)\n", " max_next_q_values = next_q_values\n", " target_q_values = rewards + GAMMA * max_next_q_values * (~dones)\n", "\n", " # Compute loss and perform backpropagation\n", " optimizer.zero_grad()\n", " loss = criterion(current_q_values, target_q_values)\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # Decay epsilon\n", " epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)\n", "\n", " if (episode + 1) % 100 == 0:\n", " avg_reward = sum(episode_rewards[-100:]) / 100\n", " print(f\"Episode {episode + 1}/{NUM_EPISODES}, Epsilon: {epsilon:.4f}, Average Reward (last 100): {avg_reward:.2f}\")\n", "\n", "print(\"RL training loop completed.\")\n", "print(f\"Final Epsilon: {epsilon:.4f}\")\n", "print(f\"Total Episodes: {len(episode_rewards)}\")\n", "print(f\"Average reward over all episodes: {sum(episode_rewards) / len(episode_rewards):.2f}\")" ], "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Starting RL training loop...\n", "Episode 100/3000, Epsilon: 0.6058, Average Reward (last 100): 0.01\n", "Episode 200/3000, Epsilon: 0.3670, Average Reward (last 100): -0.06\n", "Episode 300/3000, Epsilon: 0.2223, Average Reward (last 100): -0.07\n", "Episode 400/3000, Epsilon: 0.1347, Average Reward (last 100): -0.07\n", "Episode 500/3000, Epsilon: 0.0816, Average Reward (last 100): -0.01\n", "Episode 600/3000, Epsilon: 0.0494, Average Reward (last 100): -0.01\n", "Episode 700/3000, Epsilon: 0.0299, Average Reward (last 100): 0.00\n", "Episode 800/3000, Epsilon: 0.0181, Average Reward (last 100): -0.01\n", "Episode 900/3000, Epsilon: 0.0110, Average Reward (last 100): 0.00\n", "Episode 1000/3000, Epsilon: 0.0100, Average Reward (last 100): -0.03\n", "Episode 1100/3000, Epsilon: 0.0100, Average Reward (last 100): -0.03\n", "Episode 1200/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", "Episode 1300/3000, Epsilon: 0.0100, Average Reward (last 100): 0.01\n", 
"Episode 1400/3000, Epsilon: 0.0100, Average Reward (last 100): 0.05\n", "Episode 1500/3000, Epsilon: 0.0100, Average Reward (last 100): 0.07\n", "Episode 1600/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 1700/3000, Epsilon: 0.0100, Average Reward (last 100): -0.02\n", "Episode 1800/3000, Epsilon: 0.0100, Average Reward (last 100): -0.06\n", "Episode 1900/3000, Epsilon: 0.0100, Average Reward (last 100): -0.02\n", "Episode 2000/3000, Epsilon: 0.0100, Average Reward (last 100): -0.05\n", "Episode 2100/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", "Episode 2200/3000, Epsilon: 0.0100, Average Reward (last 100): -0.08\n", "Episode 2300/3000, Epsilon: 0.0100, Average Reward (last 100): -0.06\n", "Episode 2400/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 2500/3000, Epsilon: 0.0100, Average Reward (last 100): -0.08\n", "Episode 2600/3000, Epsilon: 0.0100, Average Reward (last 100): 0.01\n", "Episode 2700/3000, Epsilon: 0.0100, Average Reward (last 100): 0.04\n", "Episode 2800/3000, Epsilon: 0.0100, Average Reward (last 100): 0.01\n", "Episode 2900/3000, Epsilon: 0.0100, Average Reward (last 100): 0.01\n", "Episode 3000/3000, Epsilon: 0.0100, Average Reward (last 100): 0.07\n", "RL training loop completed.\n", "Final Epsilon: 0.0100\n", "Total Episodes: 3000\n", "Average reward over all episodes: -0.01\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Save Model" ], "metadata": { "id": "aZryCSqb0xGB" } }, { "cell_type": "code", "source": [ "torch.save(model.state_dict(), './chess_net_model.pth')" ], "metadata": { "id": "GIaJO-ep0zGv" }, "execution_count": 10, "outputs": [] }, { "cell_type": "code", "source": [ "# test loading\n", "loaded_model = ChessNet().to(device)\n", "loaded_model.load_state_dict(torch.load('./chess_net_model.pth', map_location=device))\n", "loaded_model.eval()\n", "\n", "board = chess.Board()\n", "evaluation = loaded_model(board_to_tensor(board).unsqueeze(0).to(device))\n", "print(evaluation)" ], "metadata": { "id": "1ygqRtG02fVt", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "fbc3ac6b-ed7c-4ed4-8f82-d088af6dab0c" }, "execution_count": 11, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([[-0.0481]], device='cuda:0', grad_fn=)\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "20863772" }, "source": [ "## Define ChessViT Model\n", "\n", "Create a new Python class `ChessViT` that implements a Vision Transformer structure based on the paper \"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale\"" ] }, { "cell_type": "code", "metadata": { "id": "46b282ae" }, "source": [ "class MultiHeadSelfAttention(nn.Module):\n", " def __init__(self, embed_dim, num_heads):\n", " super(MultiHeadSelfAttention, self).__init__()\n", " self.embed_dim = embed_dim\n", " self.num_heads = num_heads\n", " self.head_dim = embed_dim // num_heads\n", " assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n", "\n", " self.qkv = nn.Linear(embed_dim, embed_dim * 3) # Projects to Q, K, V\n", " self.proj = nn.Linear(embed_dim, embed_dim)\n", "\n", " def forward(self, x):\n", " batch_size, seq_len, embed_dim = x.size()\n", "\n", " qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)\n", " q, k, v = qkv.permute(2, 0, 3, 1, 4) # 3, B, H, S, D_H\n", "\n", " # Scaled Dot-Product Attention\n", " # (B, H, S, D_H) @ (B, H, D_H, S) -> (B, H, S, S)\n", " attention_scores = torch.matmul(q, k.transpose(-2, -1)) 
/ (self.head_dim**0.5)\n", " attention_probs = F.softmax(attention_scores, dim=-1)\n", "\n", " # (B, H, S, S) @ (B, H, S, D_H) -> (B, H, S, D_H)\n", " output = torch.matmul(attention_probs, v)\n", "\n", " # Concatenate heads and apply final projection\n", " output = output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)\n", " output = self.proj(output)\n", " return output" ], "execution_count": 12, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "28956b08" }, "source": [ "class FeedForwardNetwork(nn.Module):\n", " def __init__(self, embed_dim, hidden_dim):\n", " super(FeedForwardNetwork, self).__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(embed_dim, hidden_dim),\n", " nn.GELU(),\n", " nn.Linear(hidden_dim, embed_dim)\n", " )\n", "\n", " def forward(self, x):\n", " return self.net(x)" ], "execution_count": 13, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "dedf8bfd" }, "source": [ "class TransformerEncoderBlock(nn.Module):\n", " def __init__(self, embed_dim, num_heads, hidden_dim):\n", " super(TransformerEncoderBlock, self).__init__()\n", " self.norm1 = nn.LayerNorm(embed_dim)\n", " self.attn = MultiHeadSelfAttention(embed_dim, num_heads)\n", " self.norm2 = nn.LayerNorm(embed_dim)\n", " self.ffn = FeedForwardNetwork(embed_dim, hidden_dim)\n", "\n", " def forward(self, x):\n", " x = x + self.attn(self.norm1(x))\n", " x = x + self.ffn(self.norm2(x))\n", " return x" ], "execution_count": 14, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "6e1472bd" }, "source": [ "class ChessViT(nn.Module):\n", " def __init__(self, in_channels=12, image_size=8, patch_size=2, embed_dim=128, num_layers=2, num_heads=4, hidden_dim=None):\n", " super(ChessViT, self).__init__()\n", "\n", " self.patch_size = patch_size\n", " self.embed_dim = embed_dim\n", " self.num_layers = num_layers\n", " self.num_heads = num_heads\n", "\n", " if hidden_dim is None:\n", " hidden_dim = embed_dim * 4\n", "\n", " assert image_size % patch_size == 0, \"Image size must be divisible by patch size\"\n", " self.num_patches = (image_size // patch_size)**2\n", "\n", " # 1. Patch Embedding\n", " self.patch_embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)\n", "\n", " # 2. Class Token\n", " self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n", "\n", " # 3. Positional Embedding\n", " # num_patches + 1 for the class token\n", " self.positional_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))\n", "\n", " # 4. Transformer Encoder Blocks\n", " self.transformer_encoder_blocks = nn.ModuleList(\n", " [TransformerEncoderBlock(embed_dim, num_heads, hidden_dim) for _ in range(num_layers)]\n", " )\n", "\n", " # 5. 
MLP Head for regression (evaluation score)\n", " self.mlp_head = nn.Sequential(\n", " nn.LayerNorm(embed_dim),\n", " nn.Linear(embed_dim, 1)\n", " )\n", "\n", " def forward(self, x):\n", " batch_size = x.shape[0]\n", "\n", " # Apply patch embedding (Conv2d output: B, C, H, W)\n", " x = self.patch_embedding(x)\n", " # Flatten and transpose to (B, num_patches, embed_dim)\n", " x = x.flatten(2).transpose(1, 2)\n", "\n", " # Prepend class token\n", " cls_token = self.cls_token.expand(batch_size, -1, -1) # (B, 1, embed_dim)\n", " x = torch.cat((cls_token, x), dim=1) # (B, num_patches + 1, embed_dim)\n", "\n", " # Add positional embedding\n", " x = x + self.positional_embedding # (B, num_patches + 1, embed_dim)\n", "\n", " # Pass through Transformer Encoder Blocks\n", " for block in self.transformer_encoder_blocks:\n", " x = block(x)\n", "\n", " # Extract the output for the class token (first element)\n", " cls_token_output = x[:, 0]\n", "\n", " # Pass through MLP head for final evaluation score\n", " evaluation_score = self.mlp_head(cls_token_output)\n", "\n", " return evaluation_score" ], "execution_count": 15, "outputs": [] }, { "cell_type": "code", "source": [ "model = ChessViT()\n", "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)" ], "metadata": { "id": "sXBYr7-3JPWs" }, "execution_count": 16, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "9c555432", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "7b28938c-dda5-40a1-a92d-555deb3ba0d5" }, "source": [ "replay_buffer = ReplayBuffer(BUFFER_SIZE)\n", "epsilon = EPSILON_START\n", "episode_rewards = []\n", "\n", "criterion = nn.MSELoss()\n", "\n", "print(\"Starting RL training loop...\")\n", "\n", "for episode in range(NUM_EPISODES):\n", " game_transitions, final_game_result = play_against_itself(model, epsilon, max_moves=100)\n", "\n", " for i, (current_state_tensor, next_state_tensor, is_done_after_move) in enumerate(game_transitions):\n", " reward = 0\n", " if is_done_after_move:\n", " reward = final_game_result\n", "\n", " state_to_buffer = current_state_tensor.squeeze(0)\n", " next_state_to_buffer = next_state_tensor.squeeze(0)\n", "\n", " replay_buffer.add(state_to_buffer, reward, next_state_to_buffer, is_done_after_move)\n", "\n", " episode_rewards.append(final_game_result)\n", "\n", " if len(replay_buffer) >= BATCH_SIZE:\n", " states, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)\n", "\n", " states = torch.stack(states).float()\n", " rewards = torch.tensor(rewards).float()\n", " next_states = torch.stack(next_states).float()\n", " dones = torch.tensor(dones).bool()\n", "\n", " current_q_values = model(states).squeeze(1)\n", "\n", " with torch.no_grad():\n", " next_q_values = model(next_states).squeeze(1)\n", " max_next_q_values = next_q_values\n", " target_q_values = rewards + GAMMA * max_next_q_values * (~dones)\n", "\n", " optimizer.zero_grad()\n", " loss = criterion(current_q_values, target_q_values)\n", " loss.backward()\n", " optimizer.step()\n", "\n", " epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)\n", "\n", " if (episode + 1) % 100 == 0:\n", " avg_reward = sum(episode_rewards[-100:]) / 100\n", " print(f\"Episode {episode + 1}/{NUM_EPISODES}, Epsilon: {epsilon:.4f}, Average Reward (last 100): {avg_reward:.2f}\")\n", "\n", "print(\"RL training loop completed.\")\n", "print(f\"Final Epsilon: {epsilon:.4f}\")\n", "print(f\"Total Episodes: {len(episode_rewards)}\")\n", "print(f\"Average reward over all episodes: {sum(episode_rewards) / len(episode_rewards):.2f}\")" ], 
"execution_count": 17, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Starting RL training loop...\n", "Episode 100/3000, Epsilon: 0.6058, Average Reward (last 100): -0.01\n", "Episode 200/3000, Epsilon: 0.3670, Average Reward (last 100): -0.03\n", "Episode 300/3000, Epsilon: 0.2223, Average Reward (last 100): -0.01\n", "Episode 400/3000, Epsilon: 0.1347, Average Reward (last 100): -0.02\n", "Episode 500/3000, Epsilon: 0.0816, Average Reward (last 100): 0.00\n", "Episode 600/3000, Epsilon: 0.0494, Average Reward (last 100): 0.00\n", "Episode 700/3000, Epsilon: 0.0299, Average Reward (last 100): 0.02\n", "Episode 800/3000, Epsilon: 0.0181, Average Reward (last 100): 0.00\n", "Episode 900/3000, Epsilon: 0.0110, Average Reward (last 100): 0.00\n", "Episode 1000/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", "Episode 1100/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 1200/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 1300/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 1400/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 1500/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 1600/3000, Epsilon: 0.0100, Average Reward (last 100): 0.02\n", "Episode 1700/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 1800/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 1900/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 2000/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", "Episode 2100/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", "Episode 2200/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", "Episode 2300/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 2400/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 2500/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 2600/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 2700/3000, Epsilon: 0.0100, Average Reward (last 100): 0.05\n", "Episode 2800/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", "Episode 2900/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "Episode 3000/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", "RL training loop completed.\n", "Final Epsilon: 0.0100\n", "Total Episodes: 3000\n", "Average reward over all episodes: -0.00\n" ] } ] }, { "cell_type": "code", "metadata": { "id": "8cc8cdcd", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "ec8c16b0-e15b-42c8-ffc5-02b9346d9035" }, "source": [ "torch.save(model.state_dict(), './chess_vit_model.pth')\n", "print(\"ChessViT model saved to ./chess_vit_model.pth\")" ], "execution_count": 18, "outputs": [ { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "ChessViT model saved to ./chess_vit_model.pth\n" ] } ] }, { "cell_type": "code", "metadata": { "id": "aff6003b", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "d649fc7b-2c82-4f19-fd7b-4ea6386dc591" }, "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")\n", "\n", "loaded_vit_model = ChessViT().to(device)\n", "loaded_vit_model.load_state_dict(torch.load('./chess_vit_model.pth', map_location=device))\n", "loaded_vit_model.eval()\n", "\n", "board = chess.Board()\n", "evaluation = loaded_vit_model(board_to_tensor(board).unsqueeze(0).to(device))\n", "print(f\"Evaluation of default board 
by loaded ChessViT model: {evaluation.item():.4f}\")" ], "execution_count": 19, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Using device: cuda\n", "Evaluation of default board by loaded ChessViT model: -0.0031\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "78ee595b" }, "source": [ "## Let thje 2 models play against each other\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "75f2c74b" }, "source": [ "def play_models_against_each_other(white_model, black_model, max_moves=100, device='cpu'):\n", " board = chess.Board()\n", " white_model.eval()\n", " black_model.eval()\n", "\n", " for move_count in range(max_moves):\n", " if board.is_game_over():\n", " break\n", "\n", " legal_moves = list(board.legal_moves)\n", " if not legal_moves:\n", " break\n", "\n", " best_move = None\n", "\n", " if board.turn == chess.WHITE:\n", " current_model = white_model\n", " best_evaluation = -float('inf') # White wants to maximize the score\n", " else:\n", " current_model = black_model\n", " best_evaluation = float('inf') # Black wants to minimize the score\n", "\n", " for move in legal_moves:\n", " temp_board = board.copy()\n", " temp_board.push(move)\n", "\n", " input_tensor = board_to_tensor(temp_board).unsqueeze(0).to(device)\n", " with torch.no_grad():\n", " evaluation = current_model(input_tensor).item()\n", "\n", " if board.turn == chess.WHITE:\n", " if evaluation > best_evaluation:\n", " best_evaluation = evaluation\n", " best_move = move\n", " else:\n", " if evaluation < best_evaluation:\n", " best_evaluation = evaluation\n", " best_move = move\n", "\n", " if best_move is not None:\n", " board.push(best_move)\n", " else:\n", " break\n", "\n", " return board\n" ], "execution_count": 31, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "9c608df5", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "850e2323-cd8b-4aab-b3e5-de8fa7eac3b1" }, "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")\n", "\n", "loaded_chessnet_model = ChessNet().to(device)\n", "loaded_chessnet_model.load_state_dict(torch.load('./chess_net_model.pth', map_location=device))\n", "loaded_chessnet_model.eval()\n" ], "execution_count": 27, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Using device: cuda\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "ChessNet(\n", " (conv1): Conv2d(12, 64, kernel_size=(1, 1), stride=(1, 1))\n", " (conv2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))\n", " (lin1): Linear(in_features=8192, out_features=512, bias=True)\n", " (lin2): Linear(in_features=512, out_features=256, bias=True)\n", " (lin3): Linear(in_features=256, out_features=1, bias=True)\n", ")" ] }, "metadata": {}, "execution_count": 27 } ] }, { "cell_type": "code", "metadata": { "id": "7c92530e", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "ca0822f5-4e12-4c1b-db6f-c1f5c8c14411" }, "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")\n", "\n", "loaded_chessvit_model = ChessViT().to(device)\n", "loaded_chessvit_model.load_state_dict(torch.load('./chess_vit_model.pth', map_location=device))\n", "loaded_chessvit_model.eval()" ], "execution_count": 28, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Using device: cuda\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "ChessViT(\n", " (patch_embedding): Conv2d(12, 128, 
kernel_size=(2, 2), stride=(2, 2))\n", " (transformer_encoder_blocks): ModuleList(\n", " (0-1): 2 x TransformerEncoderBlock(\n", " (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (attn): MultiHeadSelfAttention(\n", " (qkv): Linear(in_features=128, out_features=384, bias=True)\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " )\n", " (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ffn): FeedForwardNetwork(\n", " (net): Sequential(\n", " (0): Linear(in_features=128, out_features=512, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (mlp_head): Sequential(\n", " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (1): Linear(in_features=128, out_features=1, bias=True)\n", " )\n", ")" ] }, "metadata": {}, "execution_count": 28 } ] }, { "cell_type": "code", "metadata": { "id": "04067aa1", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "7d750185-2690-4a13-c179-c891d5f7c2d1" }, "source": [ "results = []\n", "for i in range(100):\n", " game_result = play_models_against_each_other(loaded_chessnet_model, loaded_chessvit_model, device=device)\n", " results.append(game_result.result())\n", "\n", "print(\"ChessNet (White) vs ChessViT (Black):\")\n", "print({result: results.count(result) for result in set(results)})\n" ], "execution_count": 29, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "ChessNet (White) vs ChessViT (Black):\n", "{'1/2-1/2': 100}\n" ] } ] }, { "cell_type": "code", "metadata": { "id": "188527c5", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "ba5138b9-a47f-4fcd-ab1b-8393649ea24b" }, "source": [ "results = []\n", "for i in range(100):\n", " game_result = play_models_against_each_other(loaded_chessvit_model, loaded_chessnet_model, device=device)\n", " results.append(game_result.result())\n", "\n", "print(\"ChessViT (White) vs ChessNet (Black):\")\n", "print({result: results.count(result) for result in set(results)})" ], "execution_count": 30, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "ChessViT (White) vs ChessNet (Black):\n", "{'1/2-1/2': 100}\n" ] } ] }, { "cell_type": "code", "source": [ "game_result = play_models_against_each_other(loaded_chessnet_model, loaded_chessvit_model, device=device)\n", "game_result\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 411 }, "id": "rEyFB14ziH5u", "outputId": "0deeccde-97e4-46b4-862f-ebc6f07d64ef" }, "execution_count": 32, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Board('2q2bn1/p2kp2r/5p2/p1p1nPp1/1rB1Pp1P/2NP4/2P3K1/R1B5 w - - 18 36')" ], "image/svg+xml": "
" }, "metadata": {}, "execution_count": 32 } ] }, { "cell_type": "code", "source": [ "game_result = play_models_against_each_other(loaded_chessvit_model, loaded_chessnet_model, device=device)\n", "game_result\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 411 }, "id": "6S1JgndCiYsV", "outputId": "419c73a9-659c-414f-9392-38c4da335d22" }, "execution_count": 33, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Board('1nb3nr/2ppkppp/4p3/r1b5/p6q/1PPKPNPP/3PBP2/RNB1Q2R b - - 16 21')" ], "image/svg+xml": "
" }, "metadata": {}, "execution_count": 33 } ] } ] }