|
| 1 | +import torch |
| 2 | +import pathlib |
| 3 | + |
| 4 | +from chess_transformers.train.utils import get_lr |
| 5 | +from chess_transformers.configs.data.LE1222 import * |
| 6 | +from chess_transformers.configs.other.stockfish import * |
| 7 | +from chess_transformers.train.datasets import ChessDatasetFT |
| 8 | +from chess_transformers.configs.other.fairy_stockfish import * |
| 9 | +from chess_transformers.transformers.criteria import LabelSmoothedCE |
| 10 | +from chess_transformers.data.levels import TURN, PIECES, UCI_MOVES, BOOL |
| 11 | +from chess_transformers.transformers.models import ChessTransformerEncoderFT |
| 12 | + |
| 13 | + |
###############################
############ Name #############
###############################

NAME = "CT-EFT-20"  # name and identifier for this configuration

# Project root (three levels above this config file), resolved once and
# reused for every folder path below instead of repeating the traversal.
_PROJECT_ROOT = pathlib.Path(__file__).parent.parent.parent.resolve()

###############################
######### Dataloading #########
###############################

DATASET = ChessDatasetFT  # custom PyTorch dataset
BATCH_SIZE = 512  # batch size
NUM_WORKERS = 8  # number of workers to use for dataloading
PREFETCH_FACTOR = 2  # number of batches to prefetch per worker
PIN_MEMORY = False  # pin to GPU memory when dataloading?

###############################
############ Model ############
###############################

# Vocabulary sizes for each categorical input stream consumed by the model.
VOCAB_SIZES = {
    "moves": len(UCI_MOVES),
    "turn": len(TURN),
    "white_kingside_castling_rights": len(BOOL),
    "white_queenside_castling_rights": len(BOOL),
    "black_kingside_castling_rights": len(BOOL),
    "black_queenside_castling_rights": len(BOOL),
    "board_position": len(PIECES),
}  # vocabulary sizes
D_MODEL = 512  # size of vectors throughout the transformer model
N_HEADS = 8  # number of heads in the multi-head attention
D_QUERIES = 64  # size of query vectors (and also the size of the key vectors) in the multi-head attention
D_VALUES = 64  # size of value vectors in the multi-head attention
D_INNER = 2048  # an intermediate size in the position-wise FC
N_LAYERS = 6  # number of layers in the Encoder and Decoder
DROPOUT = 0.1  # dropout probability
N_MOVES = 1  # expected maximum length of move sequences in the model, <= MAX_MOVE_SEQUENCE_LENGTH
DISABLE_COMPILATION = False  # disable model compilation?
COMPILATION_MODE = "default"  # mode of model compilation (see torch.compile())
DYNAMIC_COMPILATION = True  # expect tensors with dynamic shapes?
SAMPLING_K = 1  # k in top-k sampling model predictions during play
MODEL = ChessTransformerEncoderFT  # custom PyTorch model to train

###############################
########### Training ##########
###############################

BATCHES_PER_STEP = (
    4  # perform a training step, i.e. update parameters, once every so many batches
)
PRINT_FREQUENCY = 1  # print status once every so many steps
N_STEPS = 100000  # number of training steps
WARMUP_STEPS = 8000  # number of warmup steps where learning rate is increased linearly; twice the value in the paper, as in the official transformer repo.
STEP = 1  # the step number, start from 1 to prevent math error in the next line
LR = get_lr(
    step=STEP, d_model=D_MODEL, warmup_steps=WARMUP_STEPS
)  # see utils.py for learning rate schedule; twice the schedule in the paper, as in the official transformer repo.
START_EPOCH = 0  # start at this epoch
BETAS = (0.9, 0.98)  # beta coefficients in the Adam optimizer
EPSILON = 1e-9  # epsilon term in the Adam optimizer
LABEL_SMOOTHING = 0.1  # label smoothing co-efficient in the Cross Entropy loss
BOARD_STATUS_LENGTH = 70  # total length of input sequence
USE_AMP = True  # use automatic mixed precision training?
CRITERION = LabelSmoothedCE  # training criterion (loss)
OPTIMIZER = torch.optim.Adam  # optimizer
LOGS_FOLDER = str(_PROJECT_ROOT / "train" / "logs" / NAME)  # logs folder

###############################
######### Checkpoints #########
###############################

CHECKPOINT_FOLDER = str(
    _PROJECT_ROOT / "checkpoints" / NAME
)  # folder containing checkpoints
TRAINING_CHECKPOINT = (
    NAME + ".pt"
)  # path to model checkpoint to resume training, None if none
CHECKPOINT_AVG_PREFIX = (
    "step"  # prefix to add to checkpoint name when saving checkpoints for averaging
)
CHECKPOINT_AVG_SUFFIX = (
    ".pt"  # checkpoint end string to match checkpoints saved for averaging
)
FINAL_CHECKPOINT = (
    "averaged_" + NAME + ".pt"
)  # final checkpoint to be used for eval/inference
FINAL_CHECKPOINT_GDID = (
    "1OHtg336ujlOjp5Kp0KjE1fAPF74aZpZD"  # Google Drive ID for download
)


################################
########## Evaluation ##########
################################

EVAL_GAMES_FOLDER = str(
    _PROJECT_ROOT / "eval" / "games" / NAME
)  # folder where evaluation games are saved in PGN files
0 commit comments