Commit 56028b7

Merge pull request #4 from sgrvinod/0.2.0
2 parents 3597f1f + 385e58d

File tree

98 files changed: +219,289 / -145,195 lines


.gitignore

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 *.pyc
 *.egg-info
 
-logs
 .vscode
 __pycache__
-chess_transformers/checkpoints
+checkpoints
+logs

CHANGELOG.md

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# Change Log
+
+## Unreleased (v0.2.0)
+
+### Added
+
+* **`ChessTransformerEncoderFT`** is an encoder-only transformer that predicts the source (*From*) and destination (*To*) squares for the next half-move, instead of the half-move in UCI notation.
+* [*CT-EFT-20*](https://github.com/sgrvinod/chess-transformers#ct-eft-20) is a new trained model of this type, with about 20 million parameters.
+* **`ChessDatasetFT`** is a PyTorch dataset class for this model type.
+* [**`chess_transformers.data.levels`**](https://github.com/sgrvinod/chess-transformers/blob/main/chess_transformers/data/levels.py) provides a standardized vocabulary (with indices) for oft-used categorical variables. All models and datasets will hereon use this standard vocabulary instead of dataset-specific vocabularies.
+
+### Changed
+
+* The [*LE1222*](https://github.com/sgrvinod/chess-transformers#le1222) and [*LE1222x*](https://github.com/sgrvinod/chess-transformers#le1222x) datasets no longer have their own vocabularies or vocabulary files. Instead, they use the standard vocabulary from **`chess_transformers.data.levels`**.
+* The [*LE1222*](https://github.com/sgrvinod/chess-transformers#le1222) and [*LE1222x*](https://github.com/sgrvinod/chess-transformers#le1222x) datasets have been re-encoded with indices from the standard vocabulary. Earlier versions or downloads of these datasets are no longer valid for use with this library.
+* The row index at which the validation split begins in each dataset is now stored as an attribute of the **`encoded_data`** table in the corresponding H5 file, instead of in a separate JSON file.
+* Models [*CT-E-20*](https://github.com/sgrvinod/chess-transformers#ct-e-20) and [*CT-ED-45*](https://github.com/sgrvinod/chess-transformers#ct-ed-45), already trained with a non-standard, dataset-specific vocabulary, have been refactored for use with the standard vocabulary. Earlier versions or downloads of these models are no longer valid for use with this library.
+* The field **`move_sequence`** in the H5 tables has been renamed to **`moves`**.
+* The field **`move_sequence_length`** in the H5 tables has been renamed to **`length`**.
+* The **`load_assets()`** function has been renamed to **`load_model()`**, and it no longer returns a vocabulary, only the model.
+* The **`chess_transformers/eval`** folder has been renamed to [**`chess_transformers/evaluate`**](https://github.com/sgrvinod/chess-transformers/tree/main/chess_transformers/evaluate).
+* The Python notebook **`lichess_eval.ipynb`** has been converted to a Python script, [**`evaluate.py`**](https://github.com/sgrvinod/chess-transformers/blob/main/chess_transformers/evaluate/evaluation.py), which runs evaluation much faster.
+* Fairy Stockfish is now run on 8 threads with an 8 GB hash table during evaluation, instead of 1 thread and 16 MB, which makes it a more challenging opponent.
+* Evaluation results have been recomputed for [*CT-E-20*](https://github.com/sgrvinod/chess-transformers#ct-e-20) and [*CT-ED-45*](https://github.com/sgrvinod/chess-transformers#ct-ed-45) against this stronger Fairy Stockfish; they naturally fare worse.
+
+### Removed
+
+* The environment variable **`CT_LOGS_FOLDER`** no longer needs to be set before training a model. Training logs will now always be saved to **`chess_transformers/training/logs`**.
+* The environment variable **`CT_CHECKPOINT_FOLDER`** no longer needs to be set before training a model. Checkpoints will now always be saved to **`chess_transformers/checkpoints`**.
+* The environment variable **`CT_EVAL_GAMES_FOLDER`** no longer needs to be set before evaluating a model. Evaluation games will now always be saved to **`chess_transformers/evaluate/games`**.
+
+

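The *From*/*To* formulation above replaces one classification over all UCI move tokens with two classifications over the 64 squares. A minimal sketch of how a predicted (source, destination) index pair maps back to a UCI move string, assuming the common 0-63, a1-h8 square indexing (the library's actual index tables live in `chess_transformers.data.levels`):

```python
def square_name(index: int) -> str:
    """Convert a 0-63 square index to algebraic notation (0 -> a1, 63 -> h8)."""
    file = "abcdefgh"[index % 8]   # file cycles fastest within a rank
    rank = index // 8 + 1          # ranks are numbered from 1
    return f"{file}{rank}"


def uci_move(from_index: int, to_index: int) -> str:
    """Join predicted source and destination squares into a UCI move string."""
    return square_name(from_index) + square_name(to_index)


print(uci_move(12, 28))  # e2e4
```

Promotions would need an extra output (the promotion piece), which is why this is only a sketch of the square-pair idea, not of the full model head.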
README.md

Lines changed: 118 additions & 65 deletions
Large diffs are not rendered by default.

chess_transformers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
"configs",
44
"data",
55
"train",
6-
"eval",
6+
"evaluate",
77
"play",
88
]

chess_transformers/configs/data/LE1222.py

Lines changed: 4 additions & 4 deletions
@@ -10,12 +10,12 @@
 ############ Data #############
 ###############################
 
-DATA_FOLDER = os.path.join(
-    os.environ.get("CT_DATA_FOLDER"), NAME
+DATA_FOLDER = (
+    os.path.join(os.environ.get("CT_DATA_FOLDER"), NAME)
+    if os.environ.get("CT_DATA_FOLDER")
+    else None
 )  # folder containing all data files
 H5_FILE = NAME + ".h5"  # H5 file containing data
 MAX_MOVE_SEQUENCE_LENGTH = 10  # expected maximum length of move sequences
 EXPECTED_ROWS = 12500000  # expected number of rows, approximately, in the data
-SPLITS_FILE = "splits.json"  # splits file
-VOCAB_FILE = "vocabulary.json"  # vocabulary file
 VAL_SPLIT_FRACTION = 0.925  # marker (% into the data) where the validation split begins

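The change above makes `DATA_FOLDER` tolerant of a missing `CT_DATA_FOLDER` environment variable instead of crashing on `os.path.join(None, ...)`. The guarded pattern, extracted into a small helper for illustration (the helper name and example paths are ours, not the library's):

```python
import os


def data_folder(name: str):
    """Join CT_DATA_FOLDER with the dataset name if the variable is set,
    otherwise return None, mirroring the guarded expression in the config."""
    root = os.environ.get("CT_DATA_FOLDER")
    return os.path.join(root, name) if root else None


os.environ["CT_DATA_FOLDER"] = "/data/chess"  # example value
print(data_folder("LE1222"))  # /data/chess/LE1222

del os.environ["CT_DATA_FOLDER"]
print(data_folder("LE1222"))  # None
```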
chess_transformers/configs/data/LE1222x.py

Lines changed: 4 additions & 4 deletions
@@ -10,12 +10,12 @@
 ############ Data #############
 ###############################
 
-DATA_FOLDER = os.path.join(
-    os.environ.get("CT_DATA_FOLDER"), NAME
+DATA_FOLDER = (
+    os.path.join(os.environ.get("CT_DATA_FOLDER"), NAME)
+    if os.environ.get("CT_DATA_FOLDER")
+    else None
 )  # folder containing all data files
 H5_FILE = NAME + ".h5"  # H5 file containing data
 MAX_MOVE_SEQUENCE_LENGTH = 10  # expected maximum length of move sequences
 EXPECTED_ROWS = 125000000  # expected number of rows, approximately, in the data
-SPLITS_FILE = "splits.json"  # splits file
-VOCAB_FILE = "vocabulary.json"  # vocabulary file
 VAL_SPLIT_FRACTION = 0.98  # marker (% into the data) where the validation split begins

chess_transformers/configs/models/CT-E-20.py

Lines changed: 20 additions & 18 deletions
@@ -1,12 +1,13 @@
-import os
 import torch
+import pathlib
 
+from chess_transformers.train.utils import get_lr
 from chess_transformers.configs.data.LE1222 import *
 from chess_transformers.configs.other.stockfish import *
 from chess_transformers.train.datasets import ChessDataset
 from chess_transformers.configs.other.fairy_stockfish import *
-from chess_transformers.train.utils import get_lr, get_vocab_sizes
 from chess_transformers.transformers.criteria import LabelSmoothedCE
+from chess_transformers.data.levels import TURN, PIECES, UCI_MOVES, BOOL
 from chess_transformers.transformers.models import ChessTransformerEncoder
 
 
@@ -30,7 +31,15 @@
 ############ Model ############
 ###############################
 
-VOCAB_SIZES = get_vocab_sizes(DATA_FOLDER, VOCAB_FILE)  # vocabulary sizes
+VOCAB_SIZES = {
+    "moves": len(UCI_MOVES),
+    "turn": len(TURN),
+    "white_kingside_castling_rights": len(BOOL),
+    "white_queenside_castling_rights": len(BOOL),
+    "black_kingside_castling_rights": len(BOOL),
+    "black_queenside_castling_rights": len(BOOL),
+    "board_position": len(PIECES),
+}  # vocabulary sizes
 D_MODEL = 512  # size of vectors throughout the transformer model
 N_HEADS = 8  # number of heads in the multi-head attention
 D_QUERIES = 64  # size of query vectors (and also the size of the key vectors) in the multi-head attention
@@ -67,18 +76,16 @@
 USE_AMP = True  # use automatic mixed precision training?
 CRITERION = LabelSmoothedCE  # training criterion (loss)
 OPTIMIZER = torch.optim.Adam  # optimizer
-LOGS_DIR = (
-    os.path.join(os.environ.get("CT_LOGS_FOLDER"), NAME)
-    if os.environ.get("CT_LOGS_FOLDER")
-    else None
+LOGS_FOLDER = str(
+    pathlib.Path(__file__).parent.parent.parent.resolve() / "train" / "logs" / NAME
 )  # logs folder
 
 ###############################
 ######### Checkpoints #########
 ###############################
 
-CHECKPOINT_FOLDER = os.path.join(
-    os.environ.get("CT_CHECKPOINTS_FOLDER"), NAME
+CHECKPOINT_FOLDER = str(
+    pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
 )  # folder containing checkpoints
 TRAINING_CHECKPOINT = (
     NAME + ".pt"
@@ -93,18 +100,13 @@
     "averaged_" + NAME + ".pt"
 )  # final checkpoint to be used for eval/inference
 FINAL_CHECKPOINT_GDID = (
-    "1CnLD4tBTsCJVEvs8wocaSShpppiCWy96"  # File ID on Google Drive for downloading
-)
-VOCABULARY_GDID = (
-    "1Zpw9BR5YbiWfV7TZpa03DWId41PnzvS3"  # File ID on Google Drive for download
+    "18Er4LbdujG-qiPPoqORvMQVcsiFerqY4"  # Google Drive ID for download
 )
 
 ################################
 ########## Evaluation ##########
 ################################
 
-EVAL_GAMES_FOLDER = (
-    os.path.join(os.environ.get("CT_EVAL_GAMES_FOLDER"), NAME)
-    if os.environ.get("CT_EVAL_GAMES_FOLDER")
-    else None
-)  # folder where games against Stockfish are saved in PGN files
+EVAL_GAMES_FOLDER = str(
+    pathlib.Path(__file__).parent.parent.parent.resolve() / "eval" / "games" / NAME
+)  # folder where evaluation games are saved in PGN files

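The config above now derives its folders from the file's own location via `pathlib` rather than from environment variables. A sketch of the idiom, with an illustrative stand-in for `__file__` so the path arithmetic is visible (the real configs also call `.resolve()` to make the path absolute):

```python
import pathlib

# Illustrative stand-in for a config file's own path; in the real configs
# this is __file__, e.g. .../chess_transformers/configs/models/CT-E-20.py.
config_file = pathlib.Path("/repo/chess_transformers/configs/models/CT-E-20.py")

NAME = "CT-E-20"
# Three .parent hops climb from configs/models/<file>.py to the package root.
package_root = config_file.parent.parent.parent
CHECKPOINT_FOLDER = str(package_root / "checkpoints" / NAME)
print(CHECKPOINT_FOLDER)  # /repo/chess_transformers/checkpoints/CT-E-20
```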
chess_transformers/configs/models/CT-ED-45.py

Lines changed: 20 additions & 18 deletions
@@ -1,13 +1,14 @@
-import os
 import torch
+import pathlib
 
+from chess_transformers.train.utils import get_lr
 from chess_transformers.configs.data.LE1222 import *
 from chess_transformers.configs.other.stockfish import *
 from chess_transformers.train.datasets import ChessDataset
 from chess_transformers.configs.other.fairy_stockfish import *
-from chess_transformers.train.utils import get_lr, get_vocab_sizes
 from chess_transformers.transformers.models import ChessTransformer
 from chess_transformers.transformers.criteria import LabelSmoothedCE
+from chess_transformers.data.levels import TURN, PIECES, UCI_MOVES, BOOL
 
 
 ###############################
@@ -30,7 +31,15 @@
 ############ Model ############
 ###############################
 
-VOCAB_SIZES = get_vocab_sizes(DATA_FOLDER, VOCAB_FILE)  # vocabulary sizes
+VOCAB_SIZES = {
+    "moves": len(UCI_MOVES),
+    "turn": len(TURN),
+    "white_kingside_castling_rights": len(BOOL),
+    "white_queenside_castling_rights": len(BOOL),
+    "black_kingside_castling_rights": len(BOOL),
+    "black_queenside_castling_rights": len(BOOL),
+    "board_position": len(PIECES),
+}  # vocabulary sizes
 D_MODEL = 512  # size of vectors throughout the transformer model
 N_HEADS = 8  # number of heads in the multi-head attention
 D_QUERIES = 64  # size of query vectors (and also the size of the key vectors) in the multi-head attention
@@ -67,18 +76,16 @@
 USE_AMP = True  # use automatic mixed precision training?
 CRITERION = LabelSmoothedCE  # training criterion (loss)
 OPTIMIZER = torch.optim.Adam  # optimizer
-LOGS_DIR = (
-    os.path.join(os.environ.get("CT_LOGS_FOLDER"), NAME)
-    if os.environ.get("CT_LOGS_FOLDER")
-    else None
+LOGS_FOLDER = str(
+    pathlib.Path(__file__).parent.parent.parent.resolve() / "train" / "logs" / NAME
 )  # logs folder
 
 ###############################
 ######### Checkpoints #########
 ###############################
 
-CHECKPOINT_FOLDER = os.path.join(
-    os.environ.get("CT_CHECKPOINTS_FOLDER"), NAME
+CHECKPOINT_FOLDER = str(
+    pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
 )  # folder containing checkpoints
 TRAINING_CHECKPOINT = (
     NAME + ".pt"
@@ -93,18 +100,13 @@
     "averaged_" + NAME + ".pt"
 )  # final checkpoint to be used for eval/inference
 FINAL_CHECKPOINT_GDID = (
-    "1A-IOMBkJ1mZJmAVGhBZNmPI5E94lsWYb"  # File ID on Google Drive for downloading
-)
-VOCABULARY_GDID = (
-    "1Vf0BjLN8iN7qE3FaT_FoRRPg9Lw8IDvH"  # File ID on Google Drive for download
+    "1zasRpPmZQVtAqumet9XMy1FBpmxxiM4L"  # Google Drive ID for download
 )
 
 ################################
 ########## Evaluation ##########
 ################################
 
-EVAL_GAMES_FOLDER = (
-    os.path.join(os.environ.get("CT_EVAL_GAMES_FOLDER"), NAME)
-    if os.environ.get("CT_EVAL_GAMES_FOLDER")
-    else None
-)  # folder where games against Stockfish are saved in PGN files
+EVAL_GAMES_FOLDER = str(
+    pathlib.Path(__file__).parent.parent.parent.resolve() / "eval" / "games" / NAME
+)  # folder where evaluation games are saved in PGN files
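Both model configs above now compute `VOCAB_SIZES` from the shared vocabularies in `chess_transformers.data.levels` instead of reading a per-dataset `vocabulary.json`. A self-contained sketch of the pattern, with illustrative stand-in mappings (the real tables are larger; in particular, `UCI_MOVES` enumerates every possible UCI move):

```python
# Stand-in vocabularies; the real ones live in chess_transformers.data.levels.
TURN = {"w": 0, "b": 1}
BOOL = {False: 0, True: 1}
PIECES = {
    piece: i
    for i, piece in enumerate(
        [".", "P", "N", "B", "R", "Q", "K", "p", "n", "b", "r", "q", "k"]
    )
}
UCI_MOVES = {"e2e4": 0, "d2d4": 1}  # truncated for illustration

# Vocabulary sizes follow directly from the shared tables, so every model and
# dataset agrees on them without a vocabulary file.
VOCAB_SIZES = {
    "moves": len(UCI_MOVES),
    "turn": len(TURN),
    "white_kingside_castling_rights": len(BOOL),
    "white_queenside_castling_rights": len(BOOL),
    "black_kingside_castling_rights": len(BOOL),
    "black_queenside_castling_rights": len(BOOL),
    "board_position": len(PIECES),
}
print(VOCAB_SIZES["turn"])  # 2
```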
chess_transformers/configs/models/CT-EFT-20.py (new file)

Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
+import torch
+import pathlib
+
+from chess_transformers.train.utils import get_lr
+from chess_transformers.configs.data.LE1222 import *
+from chess_transformers.configs.other.stockfish import *
+from chess_transformers.train.datasets import ChessDatasetFT
+from chess_transformers.configs.other.fairy_stockfish import *
+from chess_transformers.transformers.criteria import LabelSmoothedCE
+from chess_transformers.data.levels import TURN, PIECES, UCI_MOVES, BOOL
+from chess_transformers.transformers.models import ChessTransformerEncoderFT
+
+
+###############################
+############ Name #############
+###############################
+
+NAME = "CT-EFT-20"  # name and identifier for this configuration
+
+###############################
+######### Dataloading #########
+###############################
+
+DATASET = ChessDatasetFT  # custom PyTorch dataset
+BATCH_SIZE = 512  # batch size
+NUM_WORKERS = 8  # number of workers to use for dataloading
+PREFETCH_FACTOR = 2  # number of batches to prefetch per worker
+PIN_MEMORY = False  # pin to GPU memory when dataloading?
+
+###############################
+############ Model ############
+###############################
+
+VOCAB_SIZES = {
+    "moves": len(UCI_MOVES),
+    "turn": len(TURN),
+    "white_kingside_castling_rights": len(BOOL),
+    "white_queenside_castling_rights": len(BOOL),
+    "black_kingside_castling_rights": len(BOOL),
+    "black_queenside_castling_rights": len(BOOL),
+    "board_position": len(PIECES),
+}  # vocabulary sizes
+D_MODEL = 512  # size of vectors throughout the transformer model
+N_HEADS = 8  # number of heads in the multi-head attention
+D_QUERIES = 64  # size of query vectors (and also the size of the key vectors) in the multi-head attention
+D_VALUES = 64  # size of value vectors in the multi-head attention
+D_INNER = 2048  # an intermediate size in the position-wise FC
+N_LAYERS = 6  # number of layers in the Encoder and Decoder
+DROPOUT = 0.1  # dropout probability
+N_MOVES = 1  # expected maximum length of move sequences in the model, <= MAX_MOVE_SEQUENCE_LENGTH
+DISABLE_COMPILATION = False  # disable model compilation?
+COMPILATION_MODE = "default"  # mode of model compilation (see torch.compile())
+DYNAMIC_COMPILATION = True  # expect tensors with dynamic shapes?
+SAMPLING_K = 1  # k in top-k sampling model predictions during play
+MODEL = ChessTransformerEncoderFT  # custom PyTorch model to train
+
+###############################
+########### Training ##########
+###############################
+
+BATCHES_PER_STEP = (
+    4  # perform a training step, i.e. update parameters, once every so many batches
+)
+PRINT_FREQUENCY = 1  # print status once every so many steps
+N_STEPS = 100000  # number of training steps
+WARMUP_STEPS = 8000  # number of warmup steps where learning rate is increased linearly; twice the value in the paper, as in the official transformer repo.
+STEP = 1  # the step number, start from 1 to prevent math error in the next line
+LR = get_lr(
+    step=STEP, d_model=D_MODEL, warmup_steps=WARMUP_STEPS
+)  # see utils.py for learning rate schedule; twice the schedule in the paper, as in the official transformer repo.
+START_EPOCH = 0  # start at this epoch
+BETAS = (0.9, 0.98)  # beta coefficients in the Adam optimizer
+EPSILON = 1e-9  # epsilon term in the Adam optimizer
+LABEL_SMOOTHING = 0.1  # label smoothing co-efficient in the Cross Entropy loss
+BOARD_STATUS_LENGTH = 70  # total length of input sequence
+USE_AMP = True  # use automatic mixed precision training?
+CRITERION = LabelSmoothedCE  # training criterion (loss)
+OPTIMIZER = torch.optim.Adam  # optimizer
+LOGS_FOLDER = str(
+    pathlib.Path(__file__).parent.parent.parent.resolve() / "train" / "logs" / NAME
+)  # logs folder
+
+###############################
+######### Checkpoints #########
+###############################
+
+CHECKPOINT_FOLDER = str(
+    pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
+)  # folder containing checkpoints
+TRAINING_CHECKPOINT = (
+    NAME + ".pt"
+)  # path to model checkpoint to resume training, None if none
+CHECKPOINT_AVG_PREFIX = (
+    "step"  # prefix to add to checkpoint name when saving checkpoints for averaging
+)
+CHECKPOINT_AVG_SUFFIX = (
+    ".pt"  # checkpoint end string to match checkpoints saved for averaging
+)
+FINAL_CHECKPOINT = (
+    "averaged_" + NAME + ".pt"
+)  # final checkpoint to be used for eval/inference
+FINAL_CHECKPOINT_GDID = (
+    "1OHtg336ujlOjp5Kp0KjE1fAPF74aZpZD"  # Google Drive ID for download
+)
+
+
+################################
+########## Evaluation ##########
+################################
+
+EVAL_GAMES_FOLDER = str(
+    pathlib.Path(__file__).parent.parent.parent.resolve() / "eval" / "games" / NAME
+)  # folder where evaluation games are saved in PGN files
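The new config sets `LR = get_lr(step=STEP, d_model=D_MODEL, warmup_steps=WARMUP_STEPS)`, which its comments describe as the warmup schedule from the original transformer paper at twice the magnitude. A hedged reconstruction of such a schedule, for illustration only (the authoritative implementation is in `chess_transformers/train/utils.py`):

```python
def get_lr(step: int, d_model: int, warmup_steps: int) -> float:
    """Sketch of the 'Attention Is All You Need' schedule, scaled by 2:
    the rate rises linearly for warmup_steps steps, then decays with the
    inverse square root of the step number."""
    return 2.0 * (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)


# The rate climbs through warmup and peaks at step == warmup_steps...
warmup = [get_lr(s, d_model=512, warmup_steps=8000) for s in (1, 4000, 8000)]
# ...then decays afterwards.
after = get_lr(32000, d_model=512, warmup_steps=8000)
```

Starting `STEP` at 1, as the config's comment notes, avoids the zero-division (0 ** -0.5) this formula would hit at step 0.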
chess_transformers/configs/models/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__all__ = ["CT-E-19", "CT-ED-45"]
+__all__ = ["CT-E-19", "CT-ED-45", "CT-EFT-20"]
