A custom Transformer written in JAX to predict the next chess move in a given sequence of moves :)
- Clone the repository:
    git clone https://github.com/stackviolator/chess-transformer-nnx
    cd chess-transformer-nnx
- Install the required dependencies:
    pip install -r requirements.txt
The train.py script trains the Transformer model. Example usage:
    python train.py -c configs/transformer_dev.cfg -a configs/training_args.cfg -t src/tokenizer/vocab.json -d

The generate.py script generates chess moves using the trained Transformer model.
- Prepare the model: ensure the trained model is saved in trained_models/ and matches the configuration file in configs/.
- Run the script:
    python generate.py -m trained_models/dev -o output/temp.txt -d -k 5
- View the results: the generated move(s) are printed to the console or saved to the specified output file.
Tokenizer
- Found in src/tokenizer/.
- A custom tokenizer that maps a move in SAN (Standard Algebraic Notation) to an integer. It is trained with train_tokenizer.py.
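The tokenizer's code isn't shown in this README. As a rough illustration of how a word-level SAN tokenizer can work, the sketch below treats each whole move as one token and assigns IDs by frequency. The class name SanTokenizer, the special tokens, and the assumption that vocab.json is a flat token-to-ID JSON mapping are all hypothetical, not taken from src/tokenizer/.

```python
import json
from collections import Counter

# Hypothetical special tokens; the project's vocab.json may define different ones.
SPECIAL_TOKENS = ["<pad>", "<unk>", "<bos>", "<eos>"]
SKIP_ON_DECODE = {"<pad>", "<bos>", "<eos>"}


class SanTokenizer:
    """Word-level tokenizer: each SAN move (e.g. "Nf3", "O-O", "exd8=Q+") is one token."""

    def __init__(self, token_to_id):
        self.token_to_id = dict(token_to_id)
        self.id_to_token = {i: t for t, i in self.token_to_id.items()}
        self.unk_id = self.token_to_id["<unk>"]

    @classmethod
    def train(cls, games, min_count=1):
        """Build a vocabulary from an iterable of games, each a list of SAN strings."""
        counts = Counter(move for game in games for move in game)
        # Most frequent moves get the smallest IDs; ties are broken alphabetically
        # so the resulting IDs are deterministic.
        moves = sorted((m for m, c in counts.items() if c >= min_count),
                       key=lambda m: (-counts[m], m))
        return cls({tok: i for i, tok in enumerate(SPECIAL_TOKENS + moves)})

    def encode(self, moves, add_special=True):
        """Map SAN moves to IDs; unseen moves become <unk>."""
        ids = [self.token_to_id.get(m, self.unk_id) for m in moves]
        if add_special:
            ids = [self.token_to_id["<bos>"], *ids, self.token_to_id["<eos>"]]
        return ids

    def decode(self, ids):
        """Map IDs back to SAN moves, dropping padding and sequence markers."""
        tokens = (self.id_to_token[i] for i in ids)
        return [t for t in tokens if t not in SKIP_ON_DECODE]

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.token_to_id, f)

    @classmethod
    def load(cls, path):
        with open(path) as f:
            return cls(json.load(f))
```

For example, training on [["e4", "e5", "Nf3", "Nc6"], ["d4", "d5", "c4"]] gives IDs 0 to 3 to the special tokens and 4 to 10 to the seven moves, and encode(["e4", "e5", "Qh5"]) returns [2, 9, 10, 1, 3], with the unseen "Qh5" mapped to <unk>.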
Dataset
- Found in src/dataset/.
- GamesDataset.py: a custom dataset for handling chess moves (loading and batching for training).
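GamesDataset.py isn't reproduced here. As a sketch of what loading and batching typically involve for next-move prediction, the code below pads tokenized games to a fixed length, builds input/target pairs shifted by one position, and returns a mask marking real (non-padding) targets. ToyGamesDataset, PAD_ID = 0 and this batching scheme are assumptions for illustration, not the repository's API.

```python
import random

PAD_ID = 0  # hypothetical padding ID; it must match the tokenizer's "<pad>" entry


class ToyGamesDataset:
    """Hypothetical sketch: holds tokenized games and yields padded training batches."""

    def __init__(self, games, max_len):
        # Keep at most max_len + 1 tokens so each shifted pair has length <= max_len;
        # games with fewer than 2 tokens have no next-move target and are dropped.
        self.games = [g[: max_len + 1] for g in games if len(g) >= 2]
        self.max_len = max_len

    def __len__(self):
        return len(self.games)

    def _pad(self, seq):
        return seq + [PAD_ID] * (self.max_len - len(seq))

    def batches(self, batch_size, shuffle=True, seed=0):
        """Yield (inputs, targets, mask) lists of shape [batch, max_len]."""
        order = list(range(len(self.games)))
        if shuffle:
            random.Random(seed).shuffle(order)
        for start in range(0, len(order), batch_size):
            chunk = [self.games[i] for i in order[start:start + batch_size]]
            # Next-move prediction: the target at position t is the token at t + 1.
            inputs = [self._pad(g[:-1]) for g in chunk]
            targets = [self._pad(g[1:]) for g in chunk]
            # 1 where the target is a real token, 0 where it is padding.
            mask = [[1] * (len(g) - 1) + [0] * (self.max_len - len(g) + 1) for g in chunk]
            yield inputs, targets, mask
```

The mask lets a training loss ignore padded positions, so games of different lengths can share a batch.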
Model
- Found in src/model/.
- Transformer.py defines the Transformer architecture.
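The architecture details in Transformer.py (layer count, embedding size and so on) aren't described here. Assuming a decoder-only (GPT-style) model, the usual choice for next-token prediction, the key ingredient is causal self-attention: each position may attend only to itself and earlier moves. The single-head sketch below shows that masking in plain Python for readability; causal_self_attention is a hypothetical name, and a JAX implementation would compute the same thing with batched array operations.

```python
import math


def causal_self_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q and k are lists of T vectors of size d; v is a list of T vectors of size d_v.
    Position t attends only to positions 0..t, so it never sees future moves.
    """
    d = len(q[0])
    out = []
    for t in range(len(q)):
        # Scores only for s <= t: this restriction is the causal mask.
        scores = [sum(a * b for a, b in zip(q[t], k[s])) / math.sqrt(d)
                  for s in range(t + 1)]
        m = max(scores)  # subtract the max before exp() for numerical stability
        weights = [math.exp(x - m) for x in scores]
        total = sum(weights)
        # Output is the softmax-weighted average of the visible value vectors.
        out.append([sum(w * v[s][j] for s, w in enumerate(weights)) / total
                    for j in range(len(v[0]))])
    return out
```

In a standard Transformer this runs per head on learned projections of the token embeddings, and the head outputs are concatenated and passed through further layers; those parts are omitted here.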
Sampler
- Found in src/sampler/.
- Sampler.py implements the logic for sampling moves during inference.
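The sampling strategy in Sampler.py isn't documented here. Given the -k 5 flag passed to generate.py, top-k sampling is a plausible reading, but that is an assumption. The sketch below keeps the k highest logits, applies a softmax with an optional temperature, draws one token, and repeats autoregressively. top_k_sample, generate and the next_logits callable are hypothetical names.

```python
import math
import random


def top_k_sample(logits, k, temperature=1.0, rng=random):
    """Draw one token ID from the k highest-scoring logits (top-k sampling)."""
    if k < 1 or temperature <= 0:
        raise ValueError("k must be >= 1 and temperature must be > 0")
    # Indices of the k largest logits, highest first.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    scaled = [logits[i] / temperature for i in top]
    m = max(scaled)
    # Unnormalized softmax weights; random.choices accepts weights that don't sum to 1.
    weights = [math.exp(x - m) for x in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


def generate(next_logits, prompt, n_new, k, rng=random):
    """Autoregressively extend a list of token IDs by n_new tokens.

    next_logits is a hypothetical callable: token IDs -> logits for the next token.
    """
    ids = list(prompt)
    for _ in range(n_new):
        ids.append(top_k_sample(next_logits(ids), k, rng=rng))
    return ids
```

With k=1 this reduces to greedy decoding; larger k allows more varied continuations. Note that this sketch does not filter out illegal moves.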
Unit tests are provided in the tests/ directory. To run the tests in a given file, replace {file} with its name:
    python -m unittest tests/{file}