Commit 972aac3

Merge pull request #17 from sgrvinod/0.3.1
2 parents 2acccd1 + aa8e611

File tree

10 files changed: +104 −74 lines

CHANGELOG.md

Lines changed: 15 additions & 1 deletion

@@ -1,10 +1,24 @@
 # Change Log
 
+## v0.3.1
+
+### Added
+
+* A [**`pyproject.toml`** file](https://github.com/sgrvinod/chess-transformers/blob/main/pyproject.toml) has been added in compliance with [PEP 660](https://peps.python.org/pep-0660/). While the inclusion of a `setup.py` file is not deprecated, its use as a command-line tool, as in the legacy `setup.py develop` method for performing an editable installation, is now deprecated.
+
+### Changed
+
+* **`chess_transformers.train.datasets.ChessDataset`** was optimized for large datasets. A list of indices for the data split is no longer maintained or indexed in the dataset.
+* The `TRAINING_CHECKPOINT` parameter in each of **`chess_transformers.configs.models`** was set to `None` to reflect the correct conditions for beginning training of a model.
+* Dynamic shape tracing is disabled for the compilation of [*CT-ED-45*](https://github.com/sgrvinod/chess-transformers#ct-ed-45) to prevent memory leaks, as seen in [#16](https://github.com/sgrvinod/chess-transformers/issues/16).
+* References to `torch.cuda.amp.GradScaler(...)` have been replaced by `torch.amp.GradScaler(device="cuda", ...)` following its deprecation.
+
 ## v0.3.0
 
 ### Added
 
-* There are 3 new datasets: [ML23c](https://github.com/sgrvinod/chess-transformers#ml23c), [GC22c](https://github.com/sgrvinod/chess-transformers#gc22c), and [ML23d](https://github.com/sgrvinod/chess-transformers#ml23d).
+* There are 3 new datasets: [*ML23c*](https://github.com/sgrvinod/chess-transformers#ml23c), [*GC22c*](https://github.com/sgrvinod/chess-transformers#gc22c), and [*ML23d*](https://github.com/sgrvinod/chess-transformers#ml23d).
 * A new naming convention for datasets is used. Datasets are now named in the format "[*PGN Fileset*][*Filters*]". For example, *LE1222* is now called [*LE22ct*](https://github.com/sgrvinod/chess-transformers#le22ct), where *LE22* is the name of the PGN fileset from which this dataset was derived, and "*c*" and "*t*" are filters for games that ended in checkmates and games that used a specific time control, respectively.
 * [*CT-EFT-85*](https://github.com/sgrvinod/chess-transformers#ct-eft-85) is a new trained model with about 85 million parameters.
 * **`chess_transformers.train.utils.get_lr()`** now accepts new arguments, `schedule` and `decay`, to accommodate a new learning rate schedule: exponential decay after warmup.

README.md

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 
 <h1 align="center"><i>Chess Transformers</i></h1>
 <p align="center"><i>Teaching transformers to play chess</i></p>
-<p align="center"> <a href="https://github.com/sgrvinod/chess-transformers/releases/tag/v0.3.0"><img alt="Version" src="https://img.shields.io/github/v/tag/sgrvinod/chess-transformers?label=version"></a> <a href="https://github.com/sgrvinod/chess-transformers/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/sgrvinod/chess-transformers?label=license"></a></p>
+<p align="center"> <a href="https://github.com/sgrvinod/chess-transformers/releases/tag/v0.3.1"><img alt="Version" src="https://img.shields.io/github/v/tag/sgrvinod/chess-transformers?label=version"></a> <a href="https://github.com/sgrvinod/chess-transformers/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/sgrvinod/chess-transformers?label=license"></a></p>
 <br>
 
 *Chess Transformers* is a library for training transformer models to play chess by learning from human games.

chess_transformers/configs/models/CT-E-20.py

Lines changed: 2 additions & 2 deletions

@@ -94,8 +94,8 @@
     pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
 ) # folder containing checkpoints
 TRAINING_CHECKPOINT = (
-    NAME + ".pt"
-) # path to model checkpoint to resume training, None if none
+    None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
 CHECKPOINT_AVG_PREFIX = (
     "step" # prefix to add to checkpoint name when saving checkpoints for averaging
 )

chess_transformers/configs/models/CT-ED-45.py

Lines changed: 3 additions & 3 deletions

@@ -50,7 +50,7 @@
 N_MOVES = 10 # expected maximum length of move sequences in the model, <= MAX_MOVE_SEQUENCE_LENGTH
 DISABLE_COMPILATION = False # disable model compilation?
 COMPILATION_MODE = "default" # mode of model compilation (see torch.compile())
-DYNAMIC_COMPILATION = True # expect tensors with dynamic shapes?
+DYNAMIC_COMPILATION = False # expect tensors with dynamic shapes?
 SAMPLING_K = 1 # k in top-k sampling model predictions during play
 MODEL = ChessTransformer # custom PyTorch model to train
 
@@ -94,8 +94,8 @@
     pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
 ) # folder containing checkpoints
 TRAINING_CHECKPOINT = (
-    NAME + ".pt"
-) # path to model checkpoint to resume training, None if none
+    None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
 CHECKPOINT_AVG_PREFIX = (
     "step" # prefix to add to checkpoint name when saving checkpoints for averaging
 )

chess_transformers/configs/models/CT-EFT-20.py

Lines changed: 2 additions & 2 deletions

@@ -94,8 +94,8 @@
     pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
 ) # folder containing checkpoints
 TRAINING_CHECKPOINT = (
-    NAME + ".pt"
-) # path to model checkpoint to resume training, None if none
+    None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
 CHECKPOINT_AVG_PREFIX = (
     "step" # prefix to add to checkpoint name when saving checkpoints for averaging
 )

chess_transformers/configs/models/CT-EFT-85.py

Lines changed: 3 additions & 1 deletion

@@ -93,7 +93,9 @@
 CHECKPOINT_FOLDER = str(
     pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
 ) # folder containing checkpoints
-TRAINING_CHECKPOINT = None # path to model checkpoint to resume training, None if none
+TRAINING_CHECKPOINT = (
+    None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
 AVERAGE_STEPS = {491000, 492500, 494000, 495500, 497000, 498500, 500000}
 CHECKPOINT_AVG_PREFIX = (
     "step" # prefix to add to checkpoint name when saving checkpoints for averaging

chess_transformers/train/datasets.py

Lines changed: 31 additions & 21 deletions

@@ -34,19 +34,15 @@ def __init__(self, data_folder, h5_file, split, n_moves=None, **unused):
         # Open table in H5 file
         self.h5_file = tb.open_file(os.path.join(data_folder, h5_file), mode="r")
         self.encoded_table = self.h5_file.root.encoded_data
+        self.split = split
 
         # Create indices
-        # TODO: optimize by using a start_index and not a list of indices
         if split == "train":
-            self.indices = list(range(0, self.encoded_table.attrs.val_split_index))
+            self.first_index = 0
         elif split == "val":
-            self.indices = list(
-                range(
-                    self.encoded_table.attrs.val_split_index, self.encoded_table.nrows
-                )
-            )
+            self.first_index = self.encoded_table.attrs.val_split_index
         elif split is None:
-            self.indices = list(range(0, self.encoded_table.nrows))
+            self.first_index = 0
         else:
             raise NotImplementedError
 
@@ -56,33 +52,41 @@ def __init__(self, data_folder, h5_file, split, n_moves=None, **unused):
         if n_moves is not None:
             # This is the same as min(MAX_MOVE_SEQUENCE_LENGTH, n_moves)
             self.n_moves = min(
-                len(self.encoded_table[self.indices[0]]["moves"]) - 1, n_moves
+                len(self.encoded_table[self.first_index]["moves"]) - 1, n_moves
             )
         else:
-            self.n_moves = len(self.encoded_table[self.indices[0]]["moves"]) - 1
+            self.n_moves = len(self.encoded_table[self.first_index]["moves"]) - 1
 
     def __getitem__(self, i):
-        turns = torch.IntTensor([self.encoded_table[self.indices[i]]["turn"]])
+        turns = torch.IntTensor([self.encoded_table[self.first_index + i]["turn"]])
         white_kingside_castling_rights = torch.IntTensor(
-            [self.encoded_table[self.indices[i]]["white_kingside_castling_rights"]]
+            [self.encoded_table[self.first_index + i]["white_kingside_castling_rights"]]
         ) # (1)
         white_queenside_castling_rights = torch.IntTensor(
-            [self.encoded_table[self.indices[i]]["white_queenside_castling_rights"]]
+            [
+                self.encoded_table[self.first_index + i][
+                    "white_queenside_castling_rights"
+                ]
+            ]
         ) # (1)
         black_kingside_castling_rights = torch.IntTensor(
-            [self.encoded_table[self.indices[i]]["black_kingside_castling_rights"]]
+            [self.encoded_table[self.first_index + i]["black_kingside_castling_rights"]]
        ) # (1)
         black_queenside_castling_rights = torch.IntTensor(
-            [self.encoded_table[self.indices[i]]["black_queenside_castling_rights"]]
+            [
+                self.encoded_table[self.first_index + i][
+                    "black_queenside_castling_rights"
+                ]
+            ]
         ) # (1)
         board_position = torch.IntTensor(
-            self.encoded_table[self.indices[i]]["board_position"]
+            self.encoded_table[self.first_index + i]["board_position"]
         ) # (64)
         moves = torch.LongTensor(
-            self.encoded_table[self.indices[i]]["moves"][: self.n_moves + 1]
+            self.encoded_table[self.first_index + i]["moves"][: self.n_moves + 1]
         ) # (n_moves + 1)
         length = torch.LongTensor(
-            [self.encoded_table[self.indices[i]]["length"]]
+            [self.encoded_table[self.first_index + i]["length"]]
         ).clamp(
             max=self.n_moves
         ) # (1), value <= n_moves
 
@@ -99,7 +103,14 @@ def __getitem__(self, i):
         }
 
     def __len__(self):
-        return len(self.indices)
+        if self.split == "train":
+            return self.encoded_table.attrs.val_split_index
+        elif self.split == "val":
+            return self.encoded_table.nrows - self.encoded_table.attrs.val_split_index
+        elif self.split is None:
+            return self.encoded_table.nrows
+        else:
+            raise NotImplementedError
 
 
 class ChessDatasetFT(Dataset):
 
@@ -175,12 +186,11 @@ def __len__(self):
         elif self.split == "val":
             return self.encoded_table.nrows - self.encoded_table.attrs.val_split_index
         elif self.split is None:
-            self.encoded_table.nrows
+            return self.encoded_table.nrows
         else:
             raise NotImplementedError
 
 
-
 if __name__ == "__main__":
     # Get configuration
     parser = argparse.ArgumentParser()
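The diff above replaces a materialized Python list of row indices with a single starting offset per split, and computes the split's length on demand from the stored split boundary. A minimal, self-contained sketch of the same pattern (a plain list stands in for the PyTables table, and the names `SplitView` and `val_split_index` here are illustrative, not from the library):

```python
class SplitView:
    """Expose a train/val slice of a row store via an offset, not an index list."""

    def __init__(self, rows, val_split_index, split=None):
        self.rows = rows
        self.val_split_index = val_split_index
        self.split = split
        if split == "train" or split is None:
            self.first_index = 0  # train/full splits start at row 0
        elif split == "val":
            self.first_index = val_split_index  # val split starts at the boundary
        else:
            raise NotImplementedError

    def __getitem__(self, i):
        # O(1) offset arithmetic instead of indexing a list of indices
        return self.rows[self.first_index + i]

    def __len__(self):
        # Lengths are derived from the split boundary, so no per-row state is kept
        if self.split == "train":
            return self.val_split_index
        elif self.split == "val":
            return len(self.rows) - self.val_split_index
        return len(self.rows)
```

The payoff is memory: for a dataset of N rows, the old approach stored an N-element list per worker, while this stores two integers.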

chess_transformers/train/train.py

Lines changed: 2 additions & 2 deletions

@@ -5,7 +5,7 @@
 import torch.backends.cudnn as cudnn
 
 from tqdm import tqdm
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
@@ -96,7 +96,7 @@ def train_model(CONFIG):
     criterion = criterion.to(DEVICE)
 
     # AMP scaler
-    scaler = GradScaler(enabled=CONFIG.USE_AMP)
+    scaler = GradScaler(device=DEVICE, enabled=CONFIG.USE_AMP)
 
     # Find total epochs to train
     epochs = (CONFIG.N_STEPS // (len(train_loader) // CONFIG.BATCHES_PER_STEP)) + 1
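A code base that must run on both old and new PyTorch releases can wrap this migration in a small helper. The sketch below is hypothetical, not from the repository: the `torch` module is passed in explicitly so the dispatch logic can be shown (and exercised) without importing PyTorch itself.

```python
from types import SimpleNamespace

def make_grad_scaler(torch_module, device="cuda", enabled=True):
    """Prefer torch.amp.GradScaler; fall back to the deprecated torch.cuda.amp.GradScaler."""
    amp = getattr(torch_module, "amp", None)
    if amp is not None and hasattr(amp, "GradScaler"):
        # New-style call, matching the change in train.py above
        return amp.GradScaler(device=device, enabled=enabled)
    # Legacy path for releases that predate torch.amp.GradScaler
    return torch_module.cuda.amp.GradScaler(enabled=enabled)

# Stub modules standing in for two PyTorch versions, purely for demonstration
new_torch = SimpleNamespace(
    amp=SimpleNamespace(GradScaler=lambda device, enabled: ("new", device, enabled))
)
old_torch = SimpleNamespace(
    amp=None,
    cuda=SimpleNamespace(amp=SimpleNamespace(GradScaler=lambda enabled: ("old", enabled))),
)
```

In real code one would simply call `make_grad_scaler(torch, device=DEVICE, enabled=CONFIG.USE_AMP)`; the commit itself targets `torch==2.4.0` only, so it uses the new API directly.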

pyproject.toml

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+[build-system]
+requires = ["setuptools >= 64"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "chess-transformers"
+version = "0.3.1"
+description = "Teaching transformers to play chess."
+authors = [{ name = "Sagar Vinodababu", email = "sgrvinod@gmail.com" }]
+maintainers = [{ name = "Sagar Vinodababu", email = "sgrvinod@gmail.com" }]
+readme = "README.md"
+requires-python = ">=3.6.0"
+dependencies = [
+    "beautifulsoup4==4.12.3",
+    "chess==1.10.0",
+    "colorama==0.4.5",
+    "ipython==8.17.2",
+    "Markdown==3.3.4",
+    "py_cpuinfo==9.0.0",
+    "regex==2024.7.24",
+    "scipy==1.13.1",
+    "setuptools==69.0.3",
+    "tables==3.9.2",
+    "tabulate==0.9.0",
+    "torch==2.4.0",
+    "tqdm==4.64.1",
+    "tensorboard==2.18.0",
+]
+license = { text = "MIT License" }
+keywords = ["transformer", "chess", "pytorch", "deep learning", "chess engine"]
+classifiers = [
+    "Development Status :: 3 - Alpha",
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3.12",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+
+[project.urls]
+homepage = "https://github.com/sgrvinod/chess-transformers"
+source = "https://github.com/sgrvinod/chess-transformers"
+changelog = "https://github.com/sgrvinod/chess-transformers/blob/main/CHANGELOG.md"
+releasenotes = "https://github.com/sgrvinod/chess-transformers/releases"
+issues = "https://github.com/sgrvinod/chess-transformers/issues"

setup.py

Lines changed: 1 addition & 41 deletions

@@ -1,43 +1,3 @@
 from setuptools import setup, find_packages
 
-with open("README.md", mode="r", encoding="utf-8") as readme_file:
-    readme = readme_file.read()
-
-
-setup(
-    name="chess-transformers",
-    version="0.3.0",
-    author="Sagar Vinodababu",
-    author_email="sgrvinod@gmail.com",
-    description="Chess Transformers",
-    long_description=readme,
-    long_description_content_type="text/markdown",
-    license="MIT License",
-    url="https://github.com/sgrvinod/chess-transformers",
-    download_url="https://github.com/sgrvinod/chess-transformers",
-    packages=find_packages(),
-    python_requires=">=3.6.0",
-    install_requires=[
-        "beautifulsoup4==4.12.3",
-        "chess==1.10.0",
-        "colorama==0.4.5",
-        "ipython==8.17.2",
-        "Markdown==3.3.4",
-        "py_cpuinfo==9.0.0",
-        "regex==2024.7.24",
-        "scipy==1.13.1",
-        "setuptools==69.0.3",
-        "tables==3.9.2",
-        "tabulate==0.9.0",
-        "torch==2.4.0",
-        "tqdm==4.64.1",
-    ],
-    classifiers=[
-        "Development Status :: 3 - Alpha",
-        "Intended Audience :: Science/Research",
-        "License :: OSI Approved :: MIT License",
-        "Programming Language :: Python :: 3.9",
-        "Topic :: Scientific/Engineering :: Artificial Intelligence",
-    ],
-    keywords="transformer networks chess pytorch deep learning",
-)
+setup(packages=find_packages())
