Let’s get into a bit more technical detail on how our Connect 4-playing model will be set up and how a basic game loop works. Throughout, all code samples assume the standard PyTorch imports:

import torch
import torch.nn as nn
import torch.nn.functional as F

Board state

The current board state will be represented by a 6x7 PyTorch int8 tensor, initially filled with zeros.

    board = torch.zeros((ROWS, COLS), dtype=torch.int8, device=DEVICE)

The board is ordered such that board[0, :] is the top row. A non-empty cell is represented by +1 or -1. To simplify things, we always represent the player whose move it currently is by +1, and the opponent by -1. This way we don’t need any separate state to keep track of whose move it is. After a move has been made, we simply flip the board by doing

    board = -board

Model protocol

Any model that wants to play a game of Connect 4 will have to follow a simple protocol: it takes in the current board state as described above and outputs a float32 tensor of seven numbers which represent the probability of playing a move in each of the seven columns. A model which takes in a state and outputs a recommended action is called a policy model in RL parlance.

We will keep the model output in raw logits, that is, arbitrary numbers between minus and plus infinity with no activation function applied to them. To convert these to probabilities, we use the softmax operator, which applies an exponential function to each number and then normalizes them to add up to 1. Finally, we choose a move by sampling from the resulting probability distribution \((p_1, p_2, \ldots, p_7)\) over the seven columns.
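Concretely, if the model outputs logits \(z_1, \ldots, z_7\), softmax produces

\[
p_i = \frac{e^{z_i}}{\sum_{j=1}^{7} e^{z_j}},
\]

so each \(p_i\) is positive and they all sum to 1.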

There is one slight complication: once a column is full, i.e., board[0, c] != 0, it’s no longer valid to play a move there, and we can’t rely on the model to always output zero probabilities for full columns. So after obtaining the raw logits, we set those corresponding to illegal moves to minus infinity, which results in a zero probability for those columns after the softmax.

So a simple function to sample a valid move from a model could look like this:

def sample_move(model, board: torch.Tensor) -> int:
    """Sample a random move using the model's output logits."""
    logits = model(board)                       # Get raw logits from model
    illegal_moves = torch.where(board[0, :] == 0, 0.0, -torch.inf)
    logits += illegal_moves                     # Mask out illegal moves
    probs = F.softmax(logits, dim=-1)           # Convert logits to probabilities
    return torch.multinomial(probs, 1).item()   # Sample the distribution

For performance, we extend the model protocol to also accept an entire batch of boards for evaluation, so with a batch size of B the model maps tensors of shape

    (B, 6, 7) -> (B, 7)
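To illustrate, a batched version of the sampling helper could look roughly like this (just a sketch; the name sample_moves_batch is only for illustration and not from the repo):

def sample_moves_batch(model, boards: torch.Tensor) -> torch.Tensor:
    """Sample one move per board for a batch of boards of shape (B, 6, 7)."""
    logits = model(boards)                                      # (B, 7)
    mask = torch.where(boards[:, 0, :] == 0, 0.0, -torch.inf)   # Mask full columns
    probs = F.softmax(logits + mask, dim=-1)                    # (B, 7)
    return torch.multinomial(probs, 1).squeeze(-1)              # One column per board -> (B,)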

Probably the simplest possible model you could write is one that just makes random moves:

class RandomPlayer(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim == 2:   # Single board state: (6, 7)
            return torch.zeros((7,))
        else:             # A batch of board states: (B, 6, 7)
            return torch.zeros((x.size(0), 7))   # -> (B, 7)

That’s all we need: a vector of equal logits is mapped to equal probabilities by softmax, and we don’t need to worry about illegal moves either because sample_move takes care of that.
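As a quick sanity check, plugging this into sample_move on an empty board picks each of the seven columns with equal probability:

    board = torch.zeros((6, 7), dtype=torch.int8)
    player = RandomPlayer()
    move = sample_move(player, board)   # uniformly random column in 0..6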

Playing a game

There is one other function we need, make_move_and_check(board, move), which returns the new board state after making a move in the given column, together with a flag indicating whether the move resulted in a win. It’s pretty straightforward but a bit tedious to write because of the various directions you have to check for a winning row, so I’m not reproducing the full implementation here; you can check it out in the repo.
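To give a rough idea of the logic, a sketch could look something like this (an illustration only, not the repo’s code; it assumes the requested move is legal, i.e. the column isn’t full):

def make_move_and_check(board: torch.Tensor, move: int) -> tuple[torch.Tensor, bool]:
    """Drop a +1 piece into the given column; return the new board and a win flag."""
    board = board.clone()
    # The piece falls to the lowest empty cell: the largest row index that is still 0
    row = int((board[:, move] == 0).nonzero().max())
    board[row, move] = 1
    # Count consecutive +1 pieces through the new piece in all four directions
    for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
        count = 1
        for sign in (1, -1):
            r, c = row + sign * dr, move + sign * dc
            while 0 <= r < 6 and 0 <= c < 7 and board[r, c] == 1:
                count += 1
                r, c = r + sign * dr, c + sign * dc
        if count >= 4:
            return board, True
    return board, False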

With that, we have everything to write a complete game loop:

def play(model1, model2):
    """Have two models play against each other. Returns the winner (1 or 2)
    or 0 for a draw."""
    model1.eval()
    model2.eval()
    winner = 1
    with torch.no_grad():
        board = torch.zeros((6, 7), dtype=torch.int8, device=DEVICE)

        while True:
            # Get move from the model, play it and check for a win
            move = sample_move(model1, board)
            board, win = make_move_and_check(board, move)

            if win:
                return winner

            elif torch.all(board[0, :] != 0):  # Check if the top row is full   
                return 0    # Draw

            board = -board                  # Flip the board for the other player
            winner = 3 - winner             # Alternate between 1 and 2
            model1, model2 = model2, model1 # Swap models for the next turn

There is one other terminal condition we have to handle here: if there was no win but the entire board has filled up, which we detect by simply checking whether the top row is full, the game ended in a draw.

We now have the basic mechanism for self-play set up and can generate any number of games by having two models play against each other. In practice we play an entire batch of games at once because it’s significantly more efficient; it complicates the basic game loop above a bit but it’s still pretty straightforward.

Another simple extension of this function we will need for training is that it should return not just the final result, but also a full list of all encountered board states, the moves the model made in those states, and a vector of “returns” indicating whether the model won or lost the game. So after playing one game or a batch of games, we’ll get three tensors

    all_states:   (N, 6, 7)    dtype=torch.int8
    all_moves:    (N,)         dtype=torch.long
    all_returns:  (N,)         dtype=torch.float32

Here N is the total number of moves the model made across all games played. We’ll talk about the exact form the returns take in the next post, but for now just think of them as being +1 for a win, 0 for a draw, and -1 for a loss.
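Just to make that concrete (the exact return computation is the topic of the next post), the simplest scheme broadcasts each game’s outcome over all the moves the model made in it; game_returns and its arguments are illustrative names only:

def game_returns(num_model_moves: int, result: float) -> torch.Tensor:
    """Simplest possible returns for one game: +1.0 if the model won,
    -1.0 if it lost, 0.0 for a draw, repeated for every move it made."""
    return torch.full((num_model_moves,), result)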