diff --git a/ChessEngine.ipynb b/ChessEngine.ipynb new file mode 100644 index 0000000..0d4e3ea --- /dev/null +++ b/ChessEngine.ipynb @@ -0,0 +1,1350 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "collapsed_sections": [ + "7rz9TTaUt547", + "iPa7Uf7ltt-r", + "eBkUA_MdFwpr" + ], + "gpuType": "L4", + "machine_shape": "hm", + "authorship_tag": "ABX9TyNVvTasDtJNNgNqrI7lCCAW", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Building a Chess Engine based on CNN and one based on Vision Transformer" + ], + "metadata": { + "id": "JtjQT-pmt-Ms" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Installs and Imports" + ], + "metadata": { + "id": "7rz9TTaUt547" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install python-chess" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "QDffgdH7bar-", + "outputId": "73e2f865-7b40-4a05-e8a6-ecc8f7dfe35c" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting python-chess\n", + " Downloading python_chess-1.999-py3-none-any.whl.metadata (776 bytes)\n", + "Collecting chess<2,>=1 (from python-chess)\n", + " Downloading chess-1.11.2.tar.gz (6.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.1/6.1 MB\u001b[0m \u001b[31m80.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Downloading python_chess-1.999-py3-none-any.whl (1.4 kB)\n", + "Building wheels for collected packages: chess\n", + " Building wheel for chess (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", + " Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=05ce0aec95b34740e55ad4b33d066b405abc4b38cf207f512af7ee2e090a2722\n", + " Stored in directory: /root/.cache/pip/wheels/83/1f/4e/8f4300f7dd554eb8de70ddfed96e94d3d030ace10c5b53d447\n", + "Successfully built chess\n", + "Installing collected packages: chess, python-chess\n", + "Successfully installed chess-1.11.2 python-chess-1.999\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "import chess\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")" + ], + "metadata": { + "id": "hQXTOxy4pyN6", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "214a48d9-c691-4229-a672-52deef85f4e2" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using device: cuda\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Build the Neural Network Class" + ], + "metadata": { + "id": "iPa7Uf7ltt-r" + } + }, + { + "cell_type": "code", + "source": [ + "class ChessNet(nn.Module):\n", + "\n", + " def __init__(self):\n", + " super(ChessNet, self).__init__()\n", + " # 12 planes and 8x8 board\n", + " self.conv1 = nn.Conv2d(12, 64, 1)\n", + " self.conv2 = nn.Conv2d(64, 128, 1)\n", + " # After flattening:\n", + " self.lin1 = nn.Linear(128 * 8 * 8, 512)\n", + " self.lin2 = nn.Linear(512, 256)\n", + " self.lin3 = nn.Linear(256, 1)\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(self.conv1(x))\n", + " x = F.relu(self.conv2(x))\n", + " # Flatten the output of the convolutional layers before passing to linear layers\n", + " x = x.view(-1, 128 * 8 * 8)\n", + " x = F.relu(self.lin1(x))\n", + " x = F.relu(self.lin2(x))\n", + " x = self.lin3(x)\n", + " return x\n", + "\n", + 
"\n", + "test_model = ChessNet()\n", + "\n", + "# One test data point with shape: 12x8x8\n", + "test_data = torch.randn(1, 12, 8, 8)\n", + "\n", + "output = test_model(test_data)\n", + "print(output)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aMJG_2TUqPvW", + "outputId": "c85a69fc-0ea8-4907-cb36-f6553d8cc3dd" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor([[-0.0571]], grad_fn=)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Create a function that turns a Board into a Tensor" + ], + "metadata": { + "id": "eBkUA_MdFwpr" + } + }, + { + "cell_type": "code", + "source": [ + "def board_to_tensor(board):\n", + " tensor = torch.zeros((12, 8, 8))\n", + " pieces = board.piece_map()\n", + " for square, piece in pieces.items():\n", + " if piece.color == chess.WHITE:\n", + " tensor[0 + piece.piece_type][square // 8][square % 8] = 1\n", + " else:\n", + " tensor[5 + piece.piece_type][square // 8][square % 8] = 1\n", + " return tensor\n", + "\n", + "# test board to tensor function\n", + "test_board = chess.Board()\n", + "test_tensor = board_to_tensor(test_board)\n", + "print(test_tensor)\n", + "\n" + ], + "metadata": { + "id": "2OGSEelMFwVR", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d9dfe85d-6a4b-4946-82ac-ed21fdba9d70" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", + "\n", + " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 
0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", + "\n", + " [[0., 1., 0., 0., 0., 0., 1., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", + "\n", + " [[0., 0., 1., 0., 0., 1., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", + "\n", + " [[1., 0., 0., 0., 0., 0., 0., 1.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", + "\n", + " [[0., 0., 0., 1., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", + "\n", + " [[0., 0., 0., 0., 1., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.]],\n", + "\n", + " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 
0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 1., 0., 0., 0., 0., 1., 0.]],\n", + "\n", + " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 1., 0., 0., 1., 0., 0.]],\n", + "\n", + " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [1., 0., 0., 0., 0., 0., 0., 1.]],\n", + "\n", + " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 1., 0., 0., 0., 0.]],\n", + "\n", + " [[0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 1., 0., 0., 0.]]])\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8a2d513f" + }, + "source": [ + "## Define RL Training Parameters\n", + "\n", + "Define key parameters for reinforcement learning, such as learning rate, discount factor (gamma), exploration rate (epsilon), number of training episodes, and batch size for experience 
replay.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "26e027a2" + }, + "source": [ + "LEARNING_RATE = 5e-4\n", + "GAMMA = 0.99\n", + "EPSILON_START = 1.0\n", + "EPSILON_END = 0.01\n", + "EPSILON_DECAY = 0.995\n", + "NUM_EPISODES = 3000\n", + "BATCH_SIZE = 64\n", + "BUFFER_SIZE = 10000" + ], + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e0863229" + }, + "source": [ + "## Initialize Model and Optimizer\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0aa0e1b2" + }, + "source": [ + "model = ChessNet().to(device)\n", + "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)" + ], + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aa6a5836" + }, + "source": [ + "## Implement Experience Replay Buffer\n", + "\n", + "Create a simple experience replay buffer to store game transitions (state, reward, next_state, done) for training stability.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "f6e19433" + }, + "source": [ + "import random\n", + "from collections import deque\n", + "\n", + "class ReplayBuffer:\n", + " def __init__(self, capacity):\n", + " self.buffer = deque(maxlen=capacity)\n", + "\n", + " def add(self, state, reward, next_state, done):\n", + " self.buffer.append((state, reward, next_state, done))\n", + "\n", + " def sample(self, batch_size):\n", + " if len(self.buffer) < batch_size:\n", + " return [] # Not enough samples to form a batch\n", + " experiences = random.sample(self.buffer, batch_size)\n", + "\n", + " # Unpack experience\n", + " states, rewards, next_states, dones = zip(*experiences)\n", + "\n", + " return states, rewards, next_states, dones\n", + "\n", + " def __len__(self):\n", + " return len(self.buffer)" + ], + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Create a function where the model plays against itself" + ],
"metadata": { + "id": "QQUs_Gj8dxHW" + } + }, + { + "cell_type": "markdown", + "source": [ + "In this function the model will evaluate all the next moves from white's perspective and play the most advantageous move.\n", + "\n", + "We have 3 different outcomes and rewards:\n", + "- White wins: +1\n", + "- Black wins: -1\n", + "- Draw: 0\n", + "\n", + "We will add one policy of playing maximum 100 moves otherwise it is a draw --> Reward = 0" + ], + "metadata": { + "id": "hG-M_1leeRBu" + } + }, + { + "cell_type": "code", + "metadata": { + "id": "075de970" + }, + "source": [ + "import random\n", + "\n", + "def play_against_itself(model, epsilon, initial_board=None, max_moves=100, device='cpu'):\n", + " if initial_board is None:\n", + " board = chess.Board()\n", + " else:\n", + " board = initial_board.copy()\n", + "\n", + " game_transitions = []\n", + "\n", + " final_game_result_value = 0\n", + "\n", + " for move_count in range(max_moves):\n", + " if board.is_game_over():\n", + " break\n", + "\n", + " current_state_tensor = board_to_tensor(board).unsqueeze(0).to(device)\n", + " legal_moves = list(board.legal_moves)\n", + "\n", + " if not legal_moves:\n", + " # Stalemate or checkmate\n", + " break\n", + "\n", + " best_move = None\n", + " # exploration\n", + " if random.random() < epsilon:\n", + " random_move = random.choice(legal_moves)\n", + " best_move = random_move\n", + " else: # Exploitation\n", + " if board.turn == chess.WHITE:\n", + " best_evaluation = -float('inf') # White wants to maximize the score\n", + " else:\n", + " best_evaluation = float('inf') # Black wants to minimize the score\n", + "\n", + " # Evaluate all legal moves\n", + " for move in legal_moves:\n", + " temp_board = board.copy()\n", + " temp_board.push(move)\n", + "\n", + " input_tensor = board_to_tensor(temp_board).unsqueeze(0).to(device)\n", + " with torch.no_grad():\n", + " evaluation = model(input_tensor).item()\n", + "\n", + " if board.turn == chess.WHITE:\n", + " if evaluation > 
best_evaluation:\n", + " best_evaluation = evaluation\n", + " best_move = move\n", + " else:\n", + " if evaluation < best_evaluation:\n", + " best_evaluation = evaluation\n", + " best_move = move\n", + "\n", + " if best_move is not None:\n", + " board.push(best_move)\n", + " else:\n", + " break\n", + "\n", + " next_state_tensor = board_to_tensor(board).unsqueeze(0).to(device)\n", + " is_done_after_move = board.is_game_over()\n", + "\n", + " game_transitions.append((current_state_tensor, next_state_tensor, is_done_after_move))\n", + "\n", + " if is_done_after_move:\n", + " break\n", + "\n", + " if board.is_game_over():\n", + " result = board.result()\n", + " if result == \"1-0\":\n", + " final_game_result_value = 1 # White wins\n", + " elif result == \"0-1\":\n", + " final_game_result_value = -1 # Black wins\n", + " else:\n", + " final_game_result_value = 0 # Draw\n", + "\n", + " return game_transitions, final_game_result_value" + ], + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "271301cc" + }, + "source": [ + "## Implement RL Training Loop\n", + "\n", + "Create a training loop that plays multiple games, stores experiences in the replay buffer, samples from the buffer, calculates the Q-value loss, and updates the model's weights using backpropagation.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bb5cfc23", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "08251dff-d41d-4050-b226-a530fe02fecd" + }, + "source": [ + "replay_buffer = ReplayBuffer(BUFFER_SIZE)\n", + "epsilon = EPSILON_START\n", + "episode_rewards = []\n", + "\n", + "# Mean Squared Error for Q-learning\n", + "criterion = nn.MSELoss()\n", + "\n", + "print(\"Starting RL training loop...\")\n", + "\n", + "for episode in range(NUM_EPISODES):\n", + " game_transitions, final_game_result = play_against_itself(model, epsilon, max_moves=100, device=device)\n", + "\n", + " # Process game_transitions to add to replay 
buffer\n", + " for i, (current_state_tensor, next_state_tensor, is_done_after_move) in enumerate(game_transitions):\n", + " reward = 0\n", + " if is_done_after_move:\n", + " reward = final_game_result\n", + "\n", + " state_to_buffer = current_state_tensor.squeeze(0).cpu()\n", + " next_state_to_buffer = next_state_tensor.squeeze(0).cpu()\n", + "\n", + " replay_buffer.add(state_to_buffer, reward, next_state_to_buffer, is_done_after_move)\n", + "\n", + " episode_rewards.append(final_game_result)\n", + "\n", + " # Train the model if enough experiences are in the buffer\n", + " if len(replay_buffer) >= BATCH_SIZE:\n", + " states, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)\n", + "\n", + " states = torch.stack(states).float().to(device)\n", + " rewards = torch.tensor(rewards).float().to(device)\n", + " next_states = torch.stack(next_states).float().to(device)\n", + " dones = torch.tensor(dones).bool().to(device)\n", + "\n", + " # Get predicted Q-values for current states\n", + " current_q_values = model(states).squeeze(1)\n", + "\n", + " # Calculate target Q-values\n", + " with torch.no_grad():\n", + " next_q_values = model(next_states).squeeze(1)\n", + " max_next_q_values = next_q_values\n", + " target_q_values = rewards + GAMMA * max_next_q_values * (~dones)\n", + "\n", + " # Compute loss and perform backpropagation\n", + " optimizer.zero_grad()\n", + " loss = criterion(current_q_values, target_q_values)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # Decay epsilon\n", + " epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)\n", + "\n", + " if (episode + 1) % 100 == 0:\n", + " avg_reward = sum(episode_rewards[-100:]) / 100\n", + " print(f\"Episode {episode + 1}/{NUM_EPISODES}, Epsilon: {epsilon:.4f}, Average Reward (last 100): {avg_reward:.2f}\")\n", + "\n", + "print(\"RL training loop completed.\")\n", + "print(f\"Final Epsilon: {epsilon:.4f}\")\n", + "print(f\"Total Episodes: {len(episode_rewards)}\")\n", + "print(f\"Average 
reward over all episodes: {sum(episode_rewards) / len(episode_rewards):.2f}\")" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Starting RL training loop...\n", + "Episode 100/3000, Epsilon: 0.6058, Average Reward (last 100): 0.01\n", + "Episode 200/3000, Epsilon: 0.3670, Average Reward (last 100): -0.06\n", + "Episode 300/3000, Epsilon: 0.2223, Average Reward (last 100): -0.07\n", + "Episode 400/3000, Epsilon: 0.1347, Average Reward (last 100): -0.07\n", + "Episode 500/3000, Epsilon: 0.0816, Average Reward (last 100): -0.01\n", + "Episode 600/3000, Epsilon: 0.0494, Average Reward (last 100): -0.01\n", + "Episode 700/3000, Epsilon: 0.0299, Average Reward (last 100): 0.00\n", + "Episode 800/3000, Epsilon: 0.0181, Average Reward (last 100): -0.01\n", + "Episode 900/3000, Epsilon: 0.0110, Average Reward (last 100): 0.00\n", + "Episode 1000/3000, Epsilon: 0.0100, Average Reward (last 100): -0.03\n", + "Episode 1100/3000, Epsilon: 0.0100, Average Reward (last 100): -0.03\n", + "Episode 1200/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", + "Episode 1300/3000, Epsilon: 0.0100, Average Reward (last 100): 0.01\n", + "Episode 1400/3000, Epsilon: 0.0100, Average Reward (last 100): 0.05\n", + "Episode 1500/3000, Epsilon: 0.0100, Average Reward (last 100): 0.07\n", + "Episode 1600/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 1700/3000, Epsilon: 0.0100, Average Reward (last 100): -0.02\n", + "Episode 1800/3000, Epsilon: 0.0100, Average Reward (last 100): -0.06\n", + "Episode 1900/3000, Epsilon: 0.0100, Average Reward (last 100): -0.02\n", + "Episode 2000/3000, Epsilon: 0.0100, Average Reward (last 100): -0.05\n", + "Episode 2100/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", + "Episode 2200/3000, Epsilon: 0.0100, Average Reward (last 100): -0.08\n", + "Episode 2300/3000, Epsilon: 0.0100, Average Reward (last 100): -0.06\n", + "Episode 2400/3000, Epsilon: 0.0100, 
Average Reward (last 100): 0.00\n", + "Episode 2500/3000, Epsilon: 0.0100, Average Reward (last 100): -0.08\n", + "Episode 2600/3000, Epsilon: 0.0100, Average Reward (last 100): 0.01\n", + "Episode 2700/3000, Epsilon: 0.0100, Average Reward (last 100): 0.04\n", + "Episode 2800/3000, Epsilon: 0.0100, Average Reward (last 100): 0.01\n", + "Episode 2900/3000, Epsilon: 0.0100, Average Reward (last 100): 0.01\n", + "Episode 3000/3000, Epsilon: 0.0100, Average Reward (last 100): 0.07\n", + "RL training loop completed.\n", + "Final Epsilon: 0.0100\n", + "Total Episodes: 3000\n", + "Average reward over all episodes: -0.01\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Save Model" + ], + "metadata": { + "id": "aZryCSqb0xGB" + } + }, + { + "cell_type": "code", + "source": [ + "torch.save(model.state_dict(), './chess_net_model.pth')" + ], + "metadata": { + "id": "GIaJO-ep0zGv" + }, + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# test loading\n", + "loaded_model = ChessNet().to(device)\n", + "loaded_model.load_state_dict(torch.load('./chess_net_model.pth', map_location=device))\n", + "loaded_model.eval()\n", + "\n", + "board = chess.Board()\n", + "evaluation = loaded_model(board_to_tensor(board).unsqueeze(0).to(device))\n", + "print(evaluation)" + ], + "metadata": { + "id": "1ygqRtG02fVt", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "fbc3ac6b-ed7c-4ed4-8f82-d088af6dab0c" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor([[-0.0481]], device='cuda:0', grad_fn=)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "20863772" + }, + "source": [ + "## Define ChessViT Model\n", + "\n", + "Create a new Python class `ChessViT` that implements a Vision Transformer structure based on the paper \"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale\"" + ] + }, + { + "cell_type": 
"code", + "metadata": { + "id": "46b282ae" + }, + "source": [ + "class MultiHeadSelfAttention(nn.Module):\n", + " def __init__(self, embed_dim, num_heads):\n", + " super(MultiHeadSelfAttention, self).__init__()\n", + " self.embed_dim = embed_dim\n", + " self.num_heads = num_heads\n", + " self.head_dim = embed_dim // num_heads\n", + " assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n", + "\n", + " self.qkv = nn.Linear(embed_dim, embed_dim * 3) # Projects to Q, K, V\n", + " self.proj = nn.Linear(embed_dim, embed_dim)\n", + "\n", + " def forward(self, x):\n", + " batch_size, seq_len, embed_dim = x.size()\n", + "\n", + " qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)\n", + " q, k, v = qkv.permute(2, 0, 3, 1, 4) # 3, B, H, S, D_H\n", + "\n", + " # Scaled Dot-Product Attention\n", + " # (B, H, S, D_H) @ (B, H, D_H, S) -> (B, H, S, S)\n", + " attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)\n", + " attention_probs = F.softmax(attention_scores, dim=-1)\n", + "\n", + " # (B, H, S, S) @ (B, H, S, D_H) -> (B, H, S, D_H)\n", + " output = torch.matmul(attention_probs, v)\n", + "\n", + " # Concatenate heads and apply final projection\n", + " output = output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)\n", + " output = self.proj(output)\n", + " return output" + ], + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "28956b08" + }, + "source": [ + "class FeedForwardNetwork(nn.Module):\n", + " def __init__(self, embed_dim, hidden_dim):\n", + " super(FeedForwardNetwork, self).__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(embed_dim, hidden_dim),\n", + " nn.GELU(),\n", + " nn.Linear(hidden_dim, embed_dim)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " return self.net(x)" + ], + "execution_count": 13, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "dedf8bfd" + }, + 
"source": [ + "class TransformerEncoderBlock(nn.Module):\n", + " def __init__(self, embed_dim, num_heads, hidden_dim):\n", + " super(TransformerEncoderBlock, self).__init__()\n", + " self.norm1 = nn.LayerNorm(embed_dim)\n", + " self.attn = MultiHeadSelfAttention(embed_dim, num_heads)\n", + " self.norm2 = nn.LayerNorm(embed_dim)\n", + " self.ffn = FeedForwardNetwork(embed_dim, hidden_dim)\n", + "\n", + " def forward(self, x):\n", + " x = x + self.attn(self.norm1(x))\n", + " x = x + self.ffn(self.norm2(x))\n", + " return x" + ], + "execution_count": 14, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "6e1472bd" + }, + "source": [ + "class ChessViT(nn.Module):\n", + " def __init__(self, in_channels=12, image_size=8, patch_size=2, embed_dim=128, num_layers=2, num_heads=4, hidden_dim=None):\n", + " super(ChessViT, self).__init__()\n", + "\n", + " self.patch_size = patch_size\n", + " self.embed_dim = embed_dim\n", + " self.num_layers = num_layers\n", + " self.num_heads = num_heads\n", + "\n", + " if hidden_dim is None:\n", + " hidden_dim = embed_dim * 4\n", + "\n", + " assert image_size % patch_size == 0, \"Image size must be divisible by patch size\"\n", + " self.num_patches = (image_size // patch_size)**2\n", + "\n", + " # 1. Patch Embedding\n", + " self.patch_embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)\n", + "\n", + " # 2. Class Token\n", + " self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n", + "\n", + " # 3. Positional Embedding\n", + " # num_patches + 1 for the class token\n", + " self.positional_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))\n", + "\n", + " # 4. Transformer Encoder Blocks\n", + " self.transformer_encoder_blocks = nn.ModuleList(\n", + " [TransformerEncoderBlock(embed_dim, num_heads, hidden_dim) for _ in range(num_layers)]\n", + " )\n", + "\n", + " # 5. 
MLP Head for regression (evaluation score)\n", + " self.mlp_head = nn.Sequential(\n", + " nn.LayerNorm(embed_dim),\n", + " nn.Linear(embed_dim, 1)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " batch_size = x.shape[0]\n", + "\n", + " # Apply patch embedding (Conv2d output: B, C, H, W)\n", + " x = self.patch_embedding(x)\n", + " # Flatten and transpose to (B, num_patches, embed_dim)\n", + " x = x.flatten(2).transpose(1, 2)\n", + "\n", + " # Prepend class token\n", + " cls_token = self.cls_token.expand(batch_size, -1, -1) # (B, 1, embed_dim)\n", + " x = torch.cat((cls_token, x), dim=1) # (B, num_patches + 1, embed_dim)\n", + "\n", + " # Add positional embedding\n", + " x = x + self.positional_embedding # (B, num_patches + 1, embed_dim)\n", + "\n", + " # Pass through Transformer Encoder Blocks\n", + " for block in self.transformer_encoder_blocks:\n", + " x = block(x)\n", + "\n", + " # Extract the output for the class token (first element)\n", + " cls_token_output = x[:, 0]\n", + "\n", + " # Pass through MLP head for final evaluation score\n", + " evaluation_score = self.mlp_head(cls_token_output)\n", + "\n", + " return evaluation_score" + ], + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model = ChessViT()\n", + "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)" + ], + "metadata": { + "id": "sXBYr7-3JPWs" + }, + "execution_count": 16, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "9c555432", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "7b28938c-dda5-40a1-a92d-555deb3ba0d5" + }, + "source": [ + "replay_buffer = ReplayBuffer(BUFFER_SIZE)\n", + "epsilon = EPSILON_START\n", + "episode_rewards = []\n", + "\n", + "criterion = nn.MSELoss()\n", + "\n", + "print(\"Starting RL training loop...\")\n", + "\n", + "for episode in range(NUM_EPISODES):\n", + " game_transitions, final_game_result = play_against_itself(model, epsilon, max_moves=100)\n", + "\n", + " for 
i, (current_state_tensor, next_state_tensor, is_done_after_move) in enumerate(game_transitions):\n", + " reward = 0\n", + " if is_done_after_move:\n", + " reward = final_game_result\n", + "\n", + " state_to_buffer = current_state_tensor.squeeze(0)\n", + " next_state_to_buffer = next_state_tensor.squeeze(0)\n", + "\n", + " replay_buffer.add(state_to_buffer, reward, next_state_to_buffer, is_done_after_move)\n", + "\n", + " episode_rewards.append(final_game_result)\n", + "\n", + " if len(replay_buffer) >= BATCH_SIZE:\n", + " states, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)\n", + "\n", + " states = torch.stack(states).float()\n", + " rewards = torch.tensor(rewards).float()\n", + " next_states = torch.stack(next_states).float()\n", + " dones = torch.tensor(dones).bool()\n", + "\n", + " current_q_values = model(states).squeeze(1)\n", + "\n", + " with torch.no_grad():\n", + " next_q_values = model(next_states).squeeze(1)\n", + " max_next_q_values = next_q_values\n", + " target_q_values = rewards + GAMMA * max_next_q_values * (~dones)\n", + "\n", + " optimizer.zero_grad()\n", + " loss = criterion(current_q_values, target_q_values)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)\n", + "\n", + " if (episode + 1) % 100 == 0:\n", + " avg_reward = sum(episode_rewards[-100:]) / 100\n", + " print(f\"Episode {episode + 1}/{NUM_EPISODES}, Epsilon: {epsilon:.4f}, Average Reward (last 100): {avg_reward:.2f}\")\n", + "\n", + "print(\"RL training loop completed.\")\n", + "print(f\"Final Epsilon: {epsilon:.4f}\")\n", + "print(f\"Total Episodes: {len(episode_rewards)}\")\n", + "print(f\"Average reward over all episodes: {sum(episode_rewards) / len(episode_rewards):.2f}\")" + ], + "execution_count": 17, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Starting RL training loop...\n", + "Episode 100/3000, Epsilon: 0.6058, Average Reward (last 100): -0.01\n", + "Episode 
200/3000, Epsilon: 0.3670, Average Reward (last 100): -0.03\n", + "Episode 300/3000, Epsilon: 0.2223, Average Reward (last 100): -0.01\n", + "Episode 400/3000, Epsilon: 0.1347, Average Reward (last 100): -0.02\n", + "Episode 500/3000, Epsilon: 0.0816, Average Reward (last 100): 0.00\n", + "Episode 600/3000, Epsilon: 0.0494, Average Reward (last 100): 0.00\n", + "Episode 700/3000, Epsilon: 0.0299, Average Reward (last 100): 0.02\n", + "Episode 800/3000, Epsilon: 0.0181, Average Reward (last 100): 0.00\n", + "Episode 900/3000, Epsilon: 0.0110, Average Reward (last 100): 0.00\n", + "Episode 1000/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", + "Episode 1100/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 1200/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 1300/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 1400/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 1500/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 1600/3000, Epsilon: 0.0100, Average Reward (last 100): 0.02\n", + "Episode 1700/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 1800/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 1900/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 2000/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", + "Episode 2100/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", + "Episode 2200/3000, Epsilon: 0.0100, Average Reward (last 100): -0.01\n", + "Episode 2300/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 2400/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 2500/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 2600/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 2700/3000, Epsilon: 0.0100, Average Reward (last 100): 0.05\n", + "Episode 2800/3000, Epsilon: 0.0100, Average Reward (last 
100): -0.01\n", + "Episode 2900/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "Episode 3000/3000, Epsilon: 0.0100, Average Reward (last 100): 0.00\n", + "RL training loop completed.\n", + "Final Epsilon: 0.0100\n", + "Total Episodes: 3000\n", + "Average reward over all episodes: -0.00\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8cc8cdcd", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ec8c16b0-e15b-42c8-ffc5-02b9346d9035" + }, + "source": [ + "torch.save(model.state_dict(), './chess_vit_model.pth')\n", + "print(\"ChessViT model saved to ./chess_vit_model.pth\")" + ], + "execution_count": 18, + "outputs": [ + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "ChessViT model saved to ./chess_vit_model.pth\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "aff6003b", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d649fc7b-2c82-4f19-fd7b-4ea6386dc591" + }, + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "loaded_vit_model = ChessViT().to(device)\n", + "loaded_vit_model.load_state_dict(torch.load('./chess_vit_model.pth', map_location=device))\n", + "loaded_vit_model.eval()\n", + "\n", + "board = chess.Board()\n", + "evaluation = loaded_vit_model(board_to_tensor(board).unsqueeze(0).to(device))\n", + "print(f\"Evaluation of default board by loaded ChessViT model: {evaluation.item():.4f}\")" + ], + "execution_count": 19, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using device: cuda\n", + "Evaluation of default board by loaded ChessViT model: -0.0031\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "78ee595b" + }, + "source": [ + "## Let the 2 models play against each other\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": 
"75f2c74b" + }, + "source": [ + "def play_models_against_each_other(white_model, black_model, max_moves=100, device='cpu'):\n", + " board = chess.Board()\n", + " white_model.eval()\n", + " black_model.eval()\n", + "\n", + " for move_count in range(max_moves):\n", + " if board.is_game_over():\n", + " break\n", + "\n", + " legal_moves = list(board.legal_moves)\n", + " if not legal_moves:\n", + " break\n", + "\n", + " best_move = None\n", + "\n", + " if board.turn == chess.WHITE:\n", + " current_model = white_model\n", + " best_evaluation = -float('inf') # White wants to maximize the score\n", + " else:\n", + " current_model = black_model\n", + " best_evaluation = float('inf') # Black wants to minimize the score\n", + "\n", + " for move in legal_moves:\n", + " temp_board = board.copy()\n", + " temp_board.push(move)\n", + "\n", + " input_tensor = board_to_tensor(temp_board).unsqueeze(0).to(device)\n", + " with torch.no_grad():\n", + " evaluation = current_model(input_tensor).item()\n", + "\n", + " if board.turn == chess.WHITE:\n", + " if evaluation > best_evaluation:\n", + " best_evaluation = evaluation\n", + " best_move = move\n", + " else:\n", + " if evaluation < best_evaluation:\n", + " best_evaluation = evaluation\n", + " best_move = move\n", + "\n", + " if best_move is not None:\n", + " board.push(best_move)\n", + " else:\n", + " break\n", + "\n", + " return board\n" + ], + "execution_count": 31, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "9c608df5", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "850e2323-cd8b-4aab-b3e5-de8fa7eac3b1" + }, + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "loaded_chessnet_model = ChessNet().to(device)\n", + "loaded_chessnet_model.load_state_dict(torch.load('./chess_net_model.pth', map_location=device))\n", + "loaded_chessnet_model.eval()\n" + ], + "execution_count": 27, + "outputs": 
[ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using device: cuda\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "ChessNet(\n", + " (conv1): Conv2d(12, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " (lin1): Linear(in_features=8192, out_features=512, bias=True)\n", + " (lin2): Linear(in_features=512, out_features=256, bias=True)\n", + " (lin3): Linear(in_features=256, out_features=1, bias=True)\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 27 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "7c92530e", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ca0822f5-4e12-4c1b-db6f-c1f5c8c14411" + }, + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "loaded_chessvit_model = ChessViT().to(device)\n", + "loaded_chessvit_model.load_state_dict(torch.load('./chess_vit_model.pth', map_location=device))\n", + "loaded_chessvit_model.eval()" + ], + "execution_count": 28, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using device: cuda\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "ChessViT(\n", + " (patch_embedding): Conv2d(12, 128, kernel_size=(2, 2), stride=(2, 2))\n", + " (transformer_encoder_blocks): ModuleList(\n", + " (0-1): 2 x TransformerEncoderBlock(\n", + " (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (attn): MultiHeadSelfAttention(\n", + " (qkv): Linear(in_features=128, out_features=384, bias=True)\n", + " (proj): Linear(in_features=128, out_features=128, bias=True)\n", + " )\n", + " (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (ffn): FeedForwardNetwork(\n", + " (net): Sequential(\n", + " (0): Linear(in_features=128, out_features=512, bias=True)\n", + " (1): 
GELU(approximate='none')\n", + " (2): Linear(in_features=512, out_features=128, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (mlp_head): Sequential(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): Linear(in_features=128, out_features=1, bias=True)\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 28 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "04067aa1", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "7d750185-2690-4a13-c179-c891d5f7c2d1" + }, + "source": [ + "results = []\n", + "for i in range(100):\n", + " game_result = play_models_against_each_other(loaded_chessnet_model, loaded_chessvit_model, device=device)\n", + " results.append(game_result.result())\n", + "\n", + "print(\"ChessNet (White) vs ChessViT (Black):\")\n", + "print({result: results.count(result) for result in set(results)})\n" + ], + "execution_count": 29, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "ChessNet (White) vs ChessViT (Black):\n", + "{'1/2-1/2': 100}\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "188527c5", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ba5138b9-a47f-4fcd-ab1b-8393649ea24b" + }, + "source": [ + "results = []\n", + "for i in range(100):\n", + " game_result = play_models_against_each_other(loaded_chessvit_model, loaded_chessnet_model, device=device)\n", + " results.append(game_result.result())\n", + "\n", + "print(\"ChessViT (White) vs ChessNet (Black):\")\n", + "print({result: results.count(result) for result in set(results)})" + ], + "execution_count": 30, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "ChessViT (White) vs ChessNet (Black):\n", + "{'1/2-1/2': 100}\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "game_result = play_models_against_each_other(loaded_chessnet_model, loaded_chessvit_model, device=device)\n", + 
"game_result\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 411 + }, + "id": "rEyFB14ziH5u", + "outputId": "0deeccde-97e4-46b4-862f-ebc6f07d64ef" + }, + "execution_count": 32, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Board('2q2bn1/p2kp2r/5p2/p1p1nPp1/1rB1Pp1P/2NP4/2P3K1/R1B5 w - - 18 36')" + ], + "image/svg+xml": ". . q . . b n .\np . . k p . . r\n. . . . . p . .\np . p . n P p .\n. r B . P p . P\n. . N P . . . .\n. . P . . . K .\nR . B . . . . ." + }, + "metadata": {}, + "execution_count": 32 + } + ] + }, + { + "cell_type": "code", + "source": [ + "game_result = play_models_against_each_other(loaded_chessvit_model, loaded_chessnet_model, device=device)\n", + "game_result\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 411 + }, + "id": "6S1JgndCiYsV", + "outputId": "419c73a9-659c-414f-9392-38c4da335d22" + }, + "execution_count": 33, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Board('1nb3nr/2ppkppp/4p3/r1b5/p6q/1PPKPNPP/3PBP2/RNB1Q2R b - - 16 21')" + ], + "image/svg+xml": ". n b . . . n r\n. . p p k p p p\n. . . . p . . .\nr . b . . . . .\np . . . . . . q\n. P P K P N P P\n. . . P B P . .\nR N B . Q . . R" + }, + "metadata": {}, + "execution_count": 33 + } + ] + } + ] +} \ No newline at end of file
. . q . . b n .\np . . k p . . r\n. . . . . p . .\np . p . n P p .\n. r B . P p . P\n. . N P . . . .\n. . P . . . K .\nR . B . . . . .
. n b . . . n r\n. . p p k p p p\n. . . . p . . .\nr . b . . . . .\np . . . . . . q\n. P P K P N P P\n. . . P B P . .\nR N B . Q . . R