nshogi WCSC36 モデル定義と精度


昨年のモデルと同様,今年のWCSC36のモデルでも精度評価を行った. 結果は,以下の通りとなった.

  • Policy accuracy: 50.75%
  • Value accuracy: 75.19%
  • Policy Entropy: 1.2958

以下にモデルの定義を示す.

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

class InputLayer(nn.Module):
    """
    Turn a stack of 2-D feature planes into a sequence of token embeddings.

    A 3x3 convolution projects the ``token_dim`` input channels to ``emb_dim``
    channels, the spatial grid is flattened into a sequence, and a learned
    positional embedding is added to every token.
    """

    def __init__(self, token_dim, emb_dim, seq_length):
        """
        Args:
            token_dim (int): Number of input channels.
            emb_dim (int): Number of embedding channels.
            seq_length (int): Number of tokens; must match the flattened
                spatial size of the input for the positional embedding to
                broadcast.
        """

        super().__init__()

        self.emb_dim = emb_dim

        self.conv = nn.Conv2d(
            in_channels=token_dim,
            out_channels=emb_dim,
            kernel_size=3,
            padding=1,
            bias=False,
        )

        # Draw the positional table first, then register it as a parameter.
        positions = torch.zeros(1, seq_length, emb_dim)
        nn.init.trunc_normal_(positions, std=0.02)
        self.pos_emb = nn.Parameter(positions)

    def forward(self, x):
        # The 4-way unpack doubles as a check that the input is 4-D.
        batch, _, _, _ = x.shape

        feat = self.conv(x)

        # [batch, emb_dim, H*W] -> [batch, H*W, emb_dim]
        tokens = feat.view(batch, self.emb_dim, -1).transpose(1, 2).contiguous()

        # Learned positional embedding, broadcast over the batch.
        return tokens + self.pos_emb

class MultiHeadAttention(nn.Module):
    """
    Scaled dot-product attention over already-projected Q/K/V tensors.

    This module holds no projection weights; it only splits the inputs into
    heads, optionally layer-normalises Q and K per head, optionally appends a
    learned per-head "sink" logit, and runs the attention.
    """

    def __init__(
        self,
        emb_dim,
        num_heads,
        qk_norm=True,
        sink_attention=False,
    ):
        """
        Args:
            emb_dim (int): Dimension of input embeddings; must be divisible
                by ``num_heads``.
            num_heads (int): Number of attention heads.
            qk_norm (bool): Whether to apply LayerNorm to Q and K.
            sink_attention (bool): Whether to add a learned sink logit per head.
        """

        super().__init__()

        assert emb_dim % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.qk_norm = qk_norm
        self.sink_attention = sink_attention

        if qk_norm:
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)

        if sink_attention:
            # One learnable sink logit per head.
            self.sinks = nn.Parameter(torch.zeros(num_heads))

    def forward(self, q, k, v):
        batch, q_len, _ = q.size()
        _, kv_len, _ = k.size()
        heads, dim = self.num_heads, self.head_dim

        def split_heads(t, length):
            # [batch, length, emb_dim] -> [batch, heads, length, head_dim]
            return t.view(batch, length, heads, dim).transpose(1, 2).contiguous()

        Q = split_heads(q, q_len)
        K = split_heads(k, kv_len)
        V = split_heads(v, kv_len)

        if self.qk_norm:
            Q = self.q_norm(Q)
            K = self.k_norm(K)

        attn_bias = None
        if self.sink_attention:
            # Append an all-zero key/value slot. Its logit is exactly the
            # additive bias below and it contributes nothing to the output,
            # so it only absorbs probability mass.
            empty_slot = Q.new_zeros(batch, heads, 1, dim)
            K = torch.cat((K, empty_slot), dim=2)
            V = torch.cat((V, empty_slot), dim=2)

            attn_bias = Q.new_zeros(batch, heads, q_len, K.size(2))
            attn_bias[..., -1] = self.sinks.to(Q.dtype).view(1, heads, 1)

        attended = F.scaled_dot_product_attention(
            Q,
            K,
            V,
            attn_mask=attn_bias,
            dropout_p=0.0,
            is_causal=False,
        )

        # [batch, heads, q_len, head_dim] -> [batch, q_len, emb_dim]
        return attended.transpose(1, 2).contiguous().view(batch, q_len, -1)

class SelfAttention(nn.Module):
    """
    Self-attention block: fused QKV projection, multi-head attention with
    Q/K normalisation, and an output projection.
    """

    def __init__(
        self,
        emb_dim,
        out_dim,
        num_heads,
        has_bias,
        sink_attention,
    ):
        """
        Args:
            emb_dim (int): Dimension of input embeddings; must be divisible
                by ``num_heads``.
            out_dim (int): Dimension of output embeddings.
            num_heads (int): Number of attention heads.
            has_bias (bool): Whether the linear projections carry a bias.
            sink_attention (bool): Whether to use sink attention.
        """

        super().__init__()

        assert emb_dim % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.sink_attention = sink_attention

        # A single matmul produces Q, K and V side by side.
        self.qkv_proj = nn.Linear(emb_dim, 3 * emb_dim, bias=has_bias)

        self.mha = MultiHeadAttention(
            emb_dim,
            num_heads,
            qk_norm=True,
            sink_attention=sink_attention,
        )

        self.out = nn.Linear(emb_dim, out_dim, bias=has_bias)

    def forward(self, x):
        # The 3-way unpack doubles as a check that the input is 3-D.
        _batch, _seq_len, _dim = x.size()

        queries, keys, values = self.qkv_proj(x).chunk(3, dim=-1)

        return self.out(self.mha(queries, keys, values))

class FeedForward(nn.Module):
    """
    Position-wise feed-forward network, using either GELU or SwiGLU.
    """

    def __init__(self, emb_dim, hidden_dim, has_bias, use_swiglu):
        """
        Args:
            emb_dim (int): Dimension of input and output embeddings.
            hidden_dim (int): Dimension of the activated hidden representation.
            has_bias (bool): Whether the linear layers carry a bias.
            use_swiglu (bool): Use SwiGLU when True, GELU otherwise.
        """

        super().__init__()

        self.use_swiglu = use_swiglu

        # SwiGLU needs a gate half and a value half, so the first projection
        # is twice as wide in that mode.
        if use_swiglu:
            expanded_dim = 2 * hidden_dim
        else:
            expanded_dim = hidden_dim

        self.linear1 = nn.Linear(emb_dim, expanded_dim, bias=has_bias)
        self.linear2 = nn.Linear(hidden_dim, emb_dim, bias=has_bias)

    def forward(self, x):
        hidden = self.linear1(x)

        if not self.use_swiglu:
            return self.linear2(F.gelu(hidden))

        gate, value = hidden.chunk(2, dim=-1)
        return self.linear2(F.silu(gate) * value)

class EncoderLayer(nn.Module):
    """
    Pre-norm Transformer encoder layer.

    Each sub-layer (self-attention, then feed-forward) is applied to a
    LayerNorm-ed copy of its input and added back to the un-normalised input.
    """

    def __init__(
        self,
        emb_dim,
        num_heads,
        hidden_dim,
        has_bias,
        use_swiglu,
        sink_attention,
    ):
        """
        Args:
            emb_dim (int): Dimension of input embeddings.
            num_heads (int): Number of attention heads.
            hidden_dim (int): Dimension of hidden layer in feed-forward network;
                must be positive.
            has_bias (bool): Whether to use bias in linear layers.
            use_swiglu (bool): Whether to use SwiGLU activation in feed-forward network.
            sink_attention (bool): Whether to use sink attention.
        """

        super().__init__()

        assert hidden_dim > 0, "hidden_dim must be positive"
        self.hidden_dim = hidden_dim

        self.self_attn = SelfAttention(emb_dim, emb_dim, num_heads, has_bias, sink_attention)
        self.norm1 = nn.LayerNorm(emb_dim)

        # forward() always runs the feed-forward sub-layer, so it is always
        # built (the assertion above already guarantees hidden_dim > 0).
        self.ff = FeedForward(emb_dim, hidden_dim, has_bias, use_swiglu)

        self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        x = x + self.self_attn(self.norm1(x))

        # Feed-forward sub-layer with residual connection.
        x = x + self.ff(self.norm2(x))

        return x

class Encoder(nn.Module):
    """
    Stack of identical Transformer encoder layers applied one after another.
    """

    def __init__(
        self,
        num_layers,
        emb_dim,
        num_heads,
        hidden_dim,
        has_bias,
        use_swiglu,
        sink_attention,
    ):
        """
        Args:
            num_layers (int): Number of encoder layers.
            emb_dim (int): Dimension of input embeddings.
            num_heads (int): Number of attention heads.
            hidden_dim (int): Dimension of hidden layer in feed-forward network.
            has_bias (bool): Whether to use bias in linear layers.
            use_swiglu (bool): Whether to use SwiGLU activation in feed-forward network.
            sink_attention (bool): Whether to use sink attention.
        """

        super().__init__()

        # Every layer shares the same hyper-parameters.
        layer_args = (emb_dim, num_heads, hidden_dim, has_bias, use_swiglu, sink_attention)

        stack = []
        for _ in range(num_layers):
            stack.append(EncoderLayer(*layer_args))

        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)

        return hidden

class PolicyHead(nn.Module):
    """
    Convolutional policy head.

    Produces 27 output planes over the board and flattens them to
    ``81 * 27`` values per sample. No softmax is applied here; the output
    is unnormalised.
    """

    def __init__(self, emb_dim, num_heads):
        """
        Args:
            emb_dim (int): Number of input channels.
            num_heads (int): Not used by this head.
        """

        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels=emb_dim, out_channels=emb_dim, kernel_size=3, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(emb_dim)
        self.conv2 = nn.Conv2d(
            in_channels=emb_dim, out_channels=27, kernel_size=3, padding=1, bias=True
        )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        planes = self.conv2(hidden)

        # One row of 81 * 27 values per sample.
        return planes.view(-1, 81 * 27).contiguous()

class ValueDrawHead(nn.Module):
    """
    Head that returns two per-sample outputs: a tanh-squashed value and a
    raw (un-activated) draw output.
    """

    def __init__(self, emb_dim):
        """
        Args:
            emb_dim (int): Number of input channels.
        """

        super().__init__()

        self.conv = nn.Conv2d(in_channels=emb_dim, out_channels=2, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(2)
        self.fc1 = nn.Linear(2 * 81, 256, bias=True)
        self.fc2 = nn.Linear(256, 2, bias=True)

    def forward(self, x):
        feat = F.relu(self.bn(self.conv(x)))

        # Flatten the two 1x1-conv planes into one vector per sample.
        flat = feat.view(feat.size(0), -1).contiguous()

        out = self.fc2(F.relu(self.fc1(flat)))

        # First column is the value, second column is the draw output.
        value, draw = out.split(1, dim=-1)
        return torch.tanh(value), draw

class AttackHead(nn.Module):
    """
    Convolutional head that emits two board-shaped maps per sample.

    Per the original description, the two planes are attack maps over the
    81 squares: one from the current player's perspective and one from the
    opponent's.
    """

    def __init__(self, emb_dim):
        """
        Args:
            emb_dim (int): Number of input channels.
        """

        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels=emb_dim, out_channels=emb_dim, kernel_size=3, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(emb_dim)
        self.conv2 = nn.Conv2d(
            in_channels=emb_dim, out_channels=2, kernel_size=3, padding=1, bias=True
        )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))

        # Output keeps the spatial layout: [batch, 2, H, W].
        return self.conv2(hidden)

class CheckHead(nn.Module):
    """
    Head that emits one un-activated value per sample.

    Per the original description, this value predicts whether the current
    player is in check. No sigmoid is applied here.
    """

    def __init__(self, emb_dim):
        """
        Args:
            emb_dim (int): Number of input channels.
        """

        super().__init__()

        self.conv = nn.Conv2d(in_channels=emb_dim, out_channels=2, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(2)
        self.fc1 = nn.Linear(2 * 81, 256, bias=True)
        self.fc2 = nn.Linear(256, 1, bias=True)

    def forward(self, x):
        feat = F.relu(self.bn(self.conv(x)))

        # Flatten the two 1x1-conv planes into one vector per sample.
        flat = feat.view(feat.size(0), -1).contiguous()

        return self.fc2(F.relu(self.fc1(flat)))

class DeclarationHead(nn.Module):
    """
    Head that emits two un-activated values per sample.

    Per the original description, these are regularized declaration scores.
    NOTE(review): which side each of the two outputs refers to is not
    visible in this file; confirm against the training code.
    """

    def __init__(self, emb_dim):
        """
        Args:
            emb_dim (int): Number of input channels.
        """

        super().__init__()

        self.conv = nn.Conv2d(in_channels=emb_dim, out_channels=2, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(2)
        self.fc1 = nn.Linear(2 * 81, 256, bias=True)
        self.fc2 = nn.Linear(256, 2, bias=True)

    def forward(self, x):
        feat = F.relu(self.bn(self.conv(x)))

        # Flatten the two 1x1-conv planes into one vector per sample.
        flat = feat.view(feat.size(0), -1).contiguous()

        return self.fc2(F.relu(self.fc1(flat)))

class Transformer(nn.Module):
    """
    Encoder-only Transformer for shogi position evaluation.

    The input is embedded into a token sequence, run through the encoder
    stack and a final LayerNorm, reshaped back into a 9x9 board, and fed to
    the output heads. In evaluation mode only policy, value and draw are
    returned; in training mode the auxiliary heads are returned as well.
    """

    def __init__(
        self,
        num_layers,
        token_dim,
        emb_dim,
        num_heads,
        hidden_dim,
        seq_length,
        has_bias=True,
        use_swiglu=False,
        sink_attention=False,
    ):
        """
        Args:
            num_layers (int): Number of encoder layers.
            token_dim (int): Dimension of input tokens.
            emb_dim (int): Dimension of embeddings.
            num_heads (int): Number of attention heads.
            hidden_dim (int): Dimension of hidden layer in feed-forward network.
            seq_length (int): Length of input sequence (number of tokens).
            has_bias (bool): Whether to use bias in linear layers.
            use_swiglu (bool): Whether to use SwiGLU activation in feed-forward network.
            sink_attention (bool): Whether to use sink attention.
        """

        super().__init__()

        # Keep the hyper-parameters around for inspection.
        self.num_layers = num_layers
        self.token_dim = token_dim
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.has_bias = has_bias
        self.use_swiglu = use_swiglu
        self.sink_attention = sink_attention

        # Trunk.
        self.input_layer = InputLayer(token_dim, emb_dim, seq_length)
        self.encoder = Encoder(
            num_layers,
            emb_dim,
            num_heads,
            hidden_dim,
            has_bias,
            use_swiglu,
            sink_attention,
        )
        self.norm = nn.LayerNorm(emb_dim)

        # Output heads.
        self.policy = PolicyHead(emb_dim, num_heads)
        self.value_draw = ValueDrawHead(emb_dim)
        self.attack = AttackHead(emb_dim)
        self.check = CheckHead(emb_dim)
        self.declaration = DeclarationHead(emb_dim)

        # Re-initialise every Linear / LayerNorm / Conv2d in the tree.
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        """Initialise one submodule according to its type; others are left as-is."""
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
        elif isinstance(module, nn.Conv2d):
            nn.init.trunc_normal_(module.weight, mean=0.0, std=0.01)
        else:
            return

        # Shared by Linear and Conv2d: biases, when present, start at zero.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, x):
        batch, *_ = x.shape

        tokens = self.norm(self.encoder(self.input_layer(x)))

        # [batch, 81, emb_dim] -> [batch, emb_dim, 9, 9] for the conv heads.
        board = tokens.transpose(1, 2).contiguous().view(batch, self.emb_dim, 9, 9)

        policy = self.policy(board)
        value, draw = self.value_draw(board)

        if self.training:
            attack = self.attack(board)
            check = self.check(board)
            declaration = self.declaration(board)
            return policy, value, draw, attack, check, declaration

        # Evaluation: map the tanh value from [-1, 1] to [0, 1] and squash
        # the draw output with a sigmoid.
        return policy, (value + 1.0) / 2.0, torch.sigmoid(draw)

if __name__ == "__main__":
    import torchinfo

    # Hyper-parameters for the model whose summary is printed.
    config = dict(
        num_layers=8,
        token_dim=86,
        emb_dim=512,
        num_heads=16,
        hidden_dim=768,
        seq_length=81,
        has_bias=False,
        use_swiglu=True,
        sink_attention=False,
    )

    # Dummy batch of 8 all-zero inputs with 86 planes on a 9x9 board.
    dummy_batch = torch.zeros([8, 86, 9, 9], dtype=torch.float32)

    model = Transformer(**config)

    torchinfo.summary(model, input_data=dummy_batch)