nshogi WCSC36 モデル定義と精度
昨年のモデルと同様に, 今年のWCSC36のモデルについても精度評価を行った. 結果は以下の通りである.
- Policy accuracy:
- Value accuracy:
- Policy Entropy:
以下にモデルの定義を示す.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class InputLayer(nn.Module):
    """
    Embed a stack of board feature planes into a sequence of tokens.

    A bias-free 3x3 convolution (padding 1, so H and W are preserved) maps the
    ``token_dim`` input planes to ``emb_dim`` channels per square. The spatial
    grid is then flattened into ``H * W`` tokens, and a learned positional
    embedding is added.
    """

    def __init__(self, token_dim, emb_dim, seq_length):
        """
        Args:
            token_dim (int): Number of input channels (feature planes).
            emb_dim (int): Embedding size produced for every square / token.
            seq_length (int): Length of the learned positional embedding; it is
                added to the ``H * W`` tokens, so it should equal ``H * W``.
        """
        super().__init__()
        self.emb_dim = emb_dim
        self.conv = nn.Conv2d(token_dim, emb_dim, kernel_size=3, bias=False, padding=1)
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_length, emb_dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)

    def forward(self, x):
        """
        Args:
            x (Tensor): Input of shape ``[batch, token_dim, H, W]``.

        Returns:
            Tensor of shape ``[batch, H * W, emb_dim]``.
        """
        # The 4-way unpack rejects anything that is not a batched 4-D input.
        batch, _, _, _ = x.shape
        features = self.conv(x)
        # [batch, emb_dim, H, W] -> [batch, emb_dim, H*W] -> [batch, H*W, emb_dim]
        tokens = features.view(batch, self.emb_dim, -1).transpose(1, 2).contiguous()
        return tokens + self.pos_emb
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention on already-projected Q, K and V.

    The last dimension of each input is split into ``num_heads`` heads of size
    ``emb_dim // num_heads``.

    Options:
    * ``qk_norm``: Q and K are LayerNorm-ed per head (over ``head_dim``).
    * ``sink_attention``: one extra all-zero key/value slot is appended. Because
      its key is zero, its attention logit is exactly the learned per-head
      value ``sinks[h]`` (supplied as an additive float mask). A head can
      therefore route probability mass to a slot whose value contributes
      nothing.
    """

    def __init__(
        self,
        emb_dim,
        num_heads,
        qk_norm=True,
        sink_attention=False,
    ):
        """
        Args:
            emb_dim (int): Size of the input embeddings; must be divisible by
                ``num_heads``.
            num_heads (int): Number of attention heads.
            qk_norm (bool): Apply LayerNorm to each head of Q and K.
            sink_attention (bool): Add the learned per-head sink logit.
        """
        super().__init__()
        assert emb_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.qk_norm = qk_norm
        self.sink_attention = sink_attention
        if self.qk_norm:
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)
        if sink_attention:
            self.sinks = nn.Parameter(torch.zeros(num_heads))

    def _split_heads(self, t, batch, length):
        """Reshape ``[batch, length, emb_dim]`` to contiguous ``[batch, heads, length, head_dim]``."""
        heads = t.view(batch, length, self.num_heads, self.head_dim)
        return heads.transpose(1, 2).contiguous()

    def forward(self, q, k, v):
        """
        Args:
            q (Tensor): Queries, ``[batch, q_len, emb_dim]``.
            k (Tensor): Keys, ``[batch, kv_len, emb_dim]``.
            v (Tensor): Values, ``[batch, kv_len, emb_dim]``.

        Returns:
            Tensor of shape ``[batch, q_len, emb_dim]``.
        """
        batch, q_len, _ = q.size()
        _, kv_len, _ = k.size()
        # The batch size of q is used for k and v as well.
        queries = self._split_heads(q, batch, q_len)
        keys = self._split_heads(k, batch, kv_len)
        values = self._split_heads(v, batch, kv_len)

        if self.qk_norm:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        attn_bias = None
        if self.sink_attention:
            # One zero key/value slot per head. Its logit is 0 (zero key),
            # plus the sink bias set below.
            sink_slot = queries.new_zeros(batch, self.num_heads, 1, self.head_dim)
            keys = torch.cat([keys, sink_slot], dim=2)
            values = torch.cat([values, sink_slot], dim=2)
            attn_bias = queries.new_zeros(batch, self.num_heads, q_len, keys.size(2))
            attn_bias[..., -1] = self.sinks.view(1, self.num_heads, 1).to(queries.dtype)

        context = F.scaled_dot_product_attention(
            queries, keys, values,
            attn_mask=attn_bias,
            dropout_p=0.0,
            is_causal=False,
        )
        # [batch, heads, q_len, head_dim] -> [batch, q_len, emb_dim]
        return context.transpose(1, 2).contiguous().view(batch, q_len, -1)
class SelfAttention(nn.Module):
    """
    Self-attention block: fused QKV projection, multi-head attention, output
    projection.

    Q, K and V are produced by a single Linear and split along the last
    dimension. The heads always use QK-norm; ``sink_attention`` is forwarded to
    ``MultiHeadAttention``.
    """

    def __init__(
        self,
        emb_dim,
        out_dim,
        num_heads,
        has_bias,
        sink_attention,
    ):
        """
        Args:
            emb_dim (int): Size of the input embeddings (divisible by ``num_heads``).
            out_dim (int): Size of the projected output.
            num_heads (int): Number of attention heads.
            has_bias (bool): Whether the QKV and output Linears have a bias.
            sink_attention (bool): Whether to enable the learned sink logit.
        """
        super().__init__()
        assert emb_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.sink_attention = sink_attention
        # Produces [Q | K | V] concatenated along the last dimension.
        self.qkv_proj = nn.Linear(emb_dim, 3 * emb_dim, bias=has_bias)
        self.mha = MultiHeadAttention(
            emb_dim,
            num_heads,
            qk_norm=True,
            sink_attention=sink_attention,
        )
        self.out = nn.Linear(emb_dim, out_dim, bias=has_bias)

    def forward(self, x):
        """
        Args:
            x (Tensor): ``[batch, seq_len, emb_dim]``.

        Returns:
            Tensor of shape ``[batch, seq_len, out_dim]``.
        """
        # The 3-way unpack requires a 3-D input.
        _, _, _ = x.size()
        queries, keys, values = self.qkv_proj(x).chunk(3, dim=-1)
        return self.out(self.mha(queries, keys, values))
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network with either a GELU or a SwiGLU activation.

    * GELU: ``linear2(gelu(linear1(x)))``, where ``linear1`` is
      ``emb_dim -> hidden_dim``.
    * SwiGLU: ``linear1`` is ``emb_dim -> 2 * hidden_dim``. Its output is split
      into ``(gate, value)`` and ``linear2(silu(gate) * value)`` is returned.
    """

    def __init__(self, emb_dim, hidden_dim, has_bias, use_swiglu):
        """
        Args:
            emb_dim (int): Input and output size.
            hidden_dim (int): Size of the activation fed to ``linear2``.
            has_bias (bool): Whether both Linears have a bias.
            use_swiglu (bool): Use SwiGLU instead of GELU.
        """
        super().__init__()
        self.use_swiglu = use_swiglu
        # SwiGLU needs twice the width: one half gates the other.
        width_multiplier = 2 if use_swiglu else 1
        self.linear1 = nn.Linear(emb_dim, hidden_dim * width_multiplier, bias=has_bias)
        self.linear2 = nn.Linear(hidden_dim, emb_dim, bias=has_bias)

    def forward(self, x):
        """Apply the feed-forward network along the last dimension of ``x``."""
        hidden = self.linear1(x)
        if not self.use_swiglu:
            return self.linear2(F.gelu(hidden))
        gate, value = hidden.chunk(2, dim=-1)
        return self.linear2(F.silu(gate) * value)
class EncoderLayer(nn.Module):
    """
    Pre-LayerNorm Transformer encoder layer.

    Computes ``x = x + SelfAttention(LayerNorm(x))``. When ``hidden_dim > 0``
    it then computes ``x = x + FeedForward(LayerNorm(x))``. With
    ``hidden_dim == 0`` the layer is attention-only.
    """

    def __init__(
        self,
        emb_dim,
        num_heads,
        hidden_dim,
        has_bias,
        use_swiglu,
        sink_attention,
    ):
        """
        Args:
            emb_dim (int): Size of the input embeddings.
            num_heads (int): Number of attention heads.
            hidden_dim (int): Hidden size of the feed-forward network. 0 omits
                the feed-forward sub-layer and its LayerNorm entirely.
            has_bias (bool): Whether Linear layers have a bias.
            use_swiglu (bool): Use SwiGLU in the feed-forward network.
            sink_attention (bool): Whether to enable the attention sink logit.

        Raises:
            ValueError: If ``hidden_dim`` is negative.
        """
        super().__init__()
        # Explicit validation; an `assert` would be stripped under `python -O`.
        # That would leave a layer without `ff` that crashes in forward().
        if hidden_dim < 0:
            raise ValueError(f"hidden_dim must be >= 0, got {hidden_dim}")
        self.hidden_dim = hidden_dim
        self.self_attn = SelfAttention(emb_dim, emb_dim, num_heads, has_bias, sink_attention)
        self.norm1 = nn.LayerNorm(emb_dim)
        # Registration order (self_attn, norm1, ff, norm2) is unchanged for
        # hidden_dim > 0, so state_dict keys and init order stay the same.
        if hidden_dim > 0:
            self.ff = FeedForward(emb_dim, hidden_dim, has_bias, use_swiglu)
            self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        """
        Args:
            x (Tensor): ``[batch, seq_len, emb_dim]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.self_attn(self.norm1(x))
        # The feed-forward sub-layer only exists when hidden_dim > 0.
        if self.hidden_dim > 0:
            x = x + self.ff(self.norm2(x))
        return x
class Encoder(nn.Module):
    """
    Stack of ``num_layers`` EncoderLayers, all sharing one configuration and
    applied in order.
    """

    def __init__(
        self,
        num_layers,
        emb_dim,
        num_heads,
        hidden_dim,
        has_bias,
        use_swiglu,
        sink_attention,
    ):
        """
        Args:
            num_layers (int): Number of encoder layers.
            emb_dim (int): Size of the input embeddings.
            num_heads (int): Number of attention heads.
            hidden_dim (int): Hidden size of each feed-forward network.
            has_bias (bool): Whether Linear layers have a bias.
            use_swiglu (bool): Use SwiGLU in the feed-forward networks.
            sink_attention (bool): Whether to enable the attention sink logit.
        """
        super().__init__()
        layer_config = (emb_dim, num_heads, hidden_dim, has_bias, use_swiglu, sink_attention)
        self.layers = nn.ModuleList(
            [EncoderLayer(*layer_config) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Run ``x`` (``[batch, seq_len, emb_dim]``) through every layer in order."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden
class PolicyHead(nn.Module):
    """
    Policy head producing raw move logits for a 9x9 board.

    Architecture: 3x3 conv -> BatchNorm -> ReLU -> 3x3 conv to 27 planes per
    square. The result is flattened to ``[batch, 81 * 27]``. No softmax is
    applied here.
    """

    def __init__(self, emb_dim, num_heads):
        """
        Args:
            emb_dim (int): Number of channels of the input feature map.
            num_heads (int): Unused. Kept so existing callers keep working.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(emb_dim, emb_dim, kernel_size=3, bias=False, padding=1)
        self.bn1 = nn.BatchNorm2d(emb_dim)
        self.conv2 = nn.Conv2d(emb_dim, 27, kernel_size=3, bias=True, padding=1)

    def forward(self, x):
        """
        Args:
            x (Tensor): Feature map of shape ``[batch, emb_dim, 9, 9]``.

        Returns:
            Tensor of shape ``[batch, 81 * 27]`` (raw logits).

        Raises:
            RuntimeError: If the feature map is not 9x9. The batch size is
                pinned rather than inferred, so a wrong spatial size cannot be
                silently folded into the batch dimension.
        """
        p = F.relu(self.bn1(self.conv1(x)))
        p = self.conv2(p)
        # conv2's output is contiguous, so view() is safe. The batch size is
        # fixed explicitly; it is not inferred with -1.
        return p.view(p.size(0), 81 * 27)
class ValueDrawHead(nn.Module):
    """
    Joint value / draw head.

    Architecture: 1x1 conv to 2 channels -> BatchNorm -> ReLU -> flatten ->
    FC(256) -> ReLU -> FC(2). The first output goes through ``tanh`` (value in
    [-1, 1]). The second is returned untransformed (a draw output that the model
    passes through a sigmoid at evaluation time).
    """

    def __init__(self, emb_dim):
        """
        Args:
            emb_dim (int): Number of channels of the ``[batch, emb_dim, 9, 9]``
                input feature map.
        """
        super().__init__()
        self.conv = nn.Conv2d(emb_dim, 2, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(2)
        # 2 channels x 81 squares after flattening.
        self.fc1 = nn.Linear(2 * 81, 256, bias=True)
        self.fc2 = nn.Linear(256, 2, bias=True)

    def forward(self, x):
        """
        Returns:
            ``(value, draw)``, each of shape ``[batch, 1]``. ``value`` is in
            [-1, 1]; ``draw`` is untransformed.
        """
        squeezed = F.relu(self.bn(self.conv(x)))
        flat = squeezed.view(squeezed.size(0), -1).contiguous()
        outputs = self.fc2(F.relu(self.fc1(flat)))
        value, draw = outputs.chunk(2, dim=-1)
        return torch.tanh(value), draw
class AttackHead(nn.Module):
    """
    Auxiliary head predicting a 2-channel per-square map (attack maps).

    According to the original description, the two channels are the attack map
    from the current player's and the opponent's perspective. Which channel is
    which is fixed by the training targets, not by this module.

    Architecture: 3x3 conv -> BatchNorm -> ReLU -> 3x3 conv to 2 channels.
    """

    def __init__(self, emb_dim):
        """
        Args:
            emb_dim (int): Number of channels of the input feature map.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(emb_dim, emb_dim, kernel_size=3, bias=False, padding=1)
        self.bn1 = nn.BatchNorm2d(emb_dim)
        self.conv2 = nn.Conv2d(emb_dim, 2, kernel_size=3, bias=True, padding=1)

    def forward(self, x):
        """Map ``[batch, emb_dim, H, W]`` to raw ``[batch, 2, H, W]`` outputs."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        return self.conv2(hidden)
class CheckHead(nn.Module):
    """
    Auxiliary head producing one raw output per position.

    Described as predicting whether the current player is in check.

    Architecture: 1x1 conv to 2 channels -> BatchNorm -> ReLU -> flatten ->
    FC(256) -> ReLU -> FC(1).
    """

    def __init__(self, emb_dim):
        """
        Args:
            emb_dim (int): Number of channels of the ``[batch, emb_dim, 9, 9]``
                input feature map.
        """
        super().__init__()
        self.conv = nn.Conv2d(emb_dim, 2, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(2)
        # 2 channels x 81 squares after flattening.
        self.fc1 = nn.Linear(2 * 81, 256, bias=True)
        self.fc2 = nn.Linear(256, 1, bias=True)

    def forward(self, x):
        """Return raw outputs of shape ``[batch, 1]`` (no sigmoid applied)."""
        squeezed = F.relu(self.bn(self.conv(x)))
        flat = squeezed.view(squeezed.size(0), -1).contiguous()
        return self.fc2(F.relu(self.fc1(flat)))
class DeclarationHead(nn.Module):
    """
    Auxiliary head for the (regularized) declaration score.

    It produces two raw outputs per position, not a single scalar. Their
    interpretation is defined by the training loss, which is not part of this
    module.

    Architecture: 1x1 conv to 2 channels -> BatchNorm -> ReLU -> flatten ->
    FC(256) -> ReLU -> FC(2).
    """

    def __init__(self, emb_dim):
        """
        Args:
            emb_dim (int): Number of channels of the ``[batch, emb_dim, 9, 9]``
                input feature map.
        """
        super().__init__()
        self.conv = nn.Conv2d(emb_dim, 2, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(2)
        # 2 channels x 81 squares after flattening.
        self.fc1 = nn.Linear(2 * 81, 256, bias=True)
        self.fc2 = nn.Linear(256, 2, bias=True)

    def forward(self, x):
        """Return raw outputs of shape ``[batch, 2]``."""
        squeezed = F.relu(self.bn(self.conv(x)))
        flat = squeezed.view(squeezed.size(0), -1).contiguous()
        return self.fc2(F.relu(self.fc1(flat)))
class Transformer(nn.Module):
    """
    Encoder-only Transformer for evaluating shogi positions.

    Pipeline: InputLayer (conv embedding + positional embedding) -> Encoder ->
    final LayerNorm -> reshape back to a ``[batch, emb_dim, 9, 9]`` board ->
    task heads (policy, value/draw, attack, check, declaration).

    The output depends on the mode:
    * Eval mode: only policy, value and draw are returned, with value and draw
      rescaled to [0, 1].
    * Training mode: all head outputs are returned untransformed.
    """

    def __init__(
        self,
        num_layers,
        token_dim,
        emb_dim,
        num_heads,
        hidden_dim,
        seq_length,
        has_bias=True,
        use_swiglu=False,
        sink_attention=False,
    ):
        """
        Args:
            num_layers (int): Number of encoder layers.
            token_dim (int): Number of input feature planes.
            emb_dim (int): Embedding size.
            num_heads (int): Number of attention heads.
            hidden_dim (int): Hidden size of the feed-forward networks.
            seq_length (int): Number of tokens. forward() reshapes to 9x9, so
                this should be 81.
            has_bias (bool): Whether Linear layers in the encoder have a bias.
            use_swiglu (bool): Use SwiGLU in the feed-forward networks.
            sink_attention (bool): Whether to enable the attention sink logit.
        """
        super().__init__()
        # Hyper-parameters stored on the instance; emb_dim is also read in forward().
        self.num_layers = num_layers
        self.token_dim = token_dim
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.has_bias = has_bias
        self.use_swiglu = use_swiglu
        self.sink_attention = sink_attention

        # Registration order determines both state_dict order and the order in
        # which _init_weights consumes the RNG, so keep it stable.
        self.input_layer = InputLayer(token_dim, emb_dim, seq_length)
        self.encoder = Encoder(num_layers, emb_dim, num_heads, hidden_dim, has_bias, use_swiglu, sink_attention)
        self.norm = nn.LayerNorm(emb_dim)
        self.policy = PolicyHead(emb_dim, num_heads)
        self.value_draw = ValueDrawHead(emb_dim)
        self.attack = AttackHead(emb_dim)
        self.check = CheckHead(emb_dim)
        self.declaration = DeclarationHead(emb_dim)

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        """
        Re-initialize one module's weights.

        * LayerNorm: weight = 1, bias = 0.
        * Linear: Xavier-normal weights.
        * Conv2d: truncated normal weights (std 0.01).
        * Linear / Conv2d biases, when present, are set to zero.

        Other modules (e.g. BatchNorm2d, and parameters such as ``pos_emb`` or
        ``sinks``) are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
        elif isinstance(module, nn.Conv2d):
            nn.init.trunc_normal_(module.weight, mean=0.0, std=0.01)
        else:
            return
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """
        Args:
            x (Tensor): ``[batch, token_dim, H, W]`` input planes with
                ``H * W == 81`` (a 9x9 board).

        Returns:
            Eval mode: ``(policy, value, draw)``, where
                * ``value`` is the tanh value rescaled to [0, 1];
                * ``draw`` is ``sigmoid`` of the draw output.
            Training mode: ``(policy, value, draw, attack, check, declaration)``,
                where
                * ``value`` is the tanh output in [-1, 1];
                * the other tensors are returned untransformed.
        """
        batch, *_ = x.shape
        tokens = self.input_layer(x)
        tokens = self.norm(self.encoder(tokens))
        # [batch, 81, emb_dim] -> [batch, emb_dim, 9, 9] for the convolutional heads.
        board = tokens.transpose(1, 2).contiguous().view(batch, self.emb_dim, 9, 9)

        policy = self.policy(board)
        value, draw = self.value_draw(board)
        if self.training:
            return (
                policy,
                value,
                draw,
                self.attack(board),
                self.check(board),
                self.declaration(board),
            )
        return policy, (value + 1.0) / 2.0, torch.sigmoid(draw)
if __name__ == "__main__":
    # torchinfo is only needed for this summary, so it is imported lazily.
    import torchinfo

    # Example configuration: 86 input feature planes on a 9x9 board (81 tokens).
    config = dict(
        num_layers=8,
        token_dim=86,
        emb_dim=512,
        num_heads=16,
        hidden_dim=768,
        seq_length=81,
        has_bias=False,
        use_swiglu=True,
        sink_attention=False,
    )
    model = Transformer(**config)
    # Dummy batch of 8 positions; the model is in training mode here.
    dummy_input = torch.zeros([8, config["token_dim"], 9, 9], dtype=torch.float32)
    torchinfo.summary(model, input_data=dummy_input)