Built upon AaronCCWong's PyTorch implementation.
To load a trained model into the decoder, use TODO
Follow the instructions for the dataset you choose to work with.
First, download Karpathy's data splits here.
Download the Flickr8k images from here. Put the images in data/flickr8k/imgs/.
Place the Flickr8k data split JSON file in data/flickr8k/. It should be named dataset.json.
Run python generate_json_data.py --split-path='data/flickr8k/dataset.json' --data-path='data/flickr8k' to generate the JSON files needed for training.
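Karpathy's dataset.json holds a single images list, where each entry carries its tokenized captions and a split label. A minimal sketch of the kind of grouping generate_json_data.py performs (the exact output schema of the script is an assumption; only the input format below follows Karpathy's files):

```python
import json

def split_captions(split_path):
    """Group (filename, caption tokens) pairs by Karpathy split.

    Assumes Karpathy's format: {"images": [{"split": ...,
    "filename": ..., "sentences": [{"tokens": [...]}, ...]}, ...]}.
    """
    with open(split_path) as f:
        data = json.load(f)

    splits = {"train": [], "val": [], "test": []}
    for img in data["images"]:
        split = img["split"]
        if split == "restval":  # COCO's extra "restval" split is folded into train
            split = "train"
        for sent in img["sentences"]:
            splits[split].append((img["filename"], sent["tokens"]))
    return splits
```

Each of the three lists would then be written out as its own JSON file under the data directory.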
If you want to use pre-trained BERT embeddings (bert=True), additionally run python generate_json_data_bert.py --split-path='data/flickr8k/dataset.json' --data-path='data/flickr8k' to generate the BERT-tokenized caption JSON files.
Download the COCO dataset training and validation images. Put them in data/coco/imgs/train2014 and data/coco/imgs/val2014, respectively.
Put the COCO dataset split JSON file from Karpathy in data/coco/. It should be named dataset.json.
Run python generate_json_data.py --split-path='data/coco/dataset.json' --data-path='data/coco' to generate the JSON files needed for training.
Start the training by running:
python train.py --data=data/flickr8k
or to make a small test run:
python train.py --data=data/flickr8k --tf --ado --attention --epochs=1 --frac=0.02 --log-interval=2
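The flags in the test-run command suggest an argument parser roughly like the sketch below. The flag names are taken from the command above, but their semantics and defaults are assumptions (e.g. --tf for teacher forcing); train.py is the authoritative source.

```python
import argparse

# Hypothetical reconstruction of train.py's CLI; flag meanings are guesses.
parser = argparse.ArgumentParser(description="Train the captioning model")
parser.add_argument("--data", default="data/coco", help="path to the dataset directory")
parser.add_argument("--tf", action="store_true", help="assumed: enable teacher forcing")
parser.add_argument("--ado", action="store_true", help="assumed: a training option defined in train.py")
parser.add_argument("--attention", action="store_true", help="assumed: use the attention decoder")
parser.add_argument("--epochs", type=int, default=10, help="number of training epochs")
parser.add_argument("--frac", type=float, default=1.0, help="assumed: fraction of the dataset to use")
parser.add_argument("--log-interval", type=int, default=100, help="batches between log lines")

# Parsing the small test run from above:
args = parser.parse_args([
    "--data=data/flickr8k", "--tf", "--ado", "--attention",
    "--epochs=1", "--frac=0.02", "--log-interval=2",
])
```

Lowering --frac and --epochs this way is what makes the test run cheap: it trains on a small slice of the data for a single epoch.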
The models are saved in model/ and the training statistics are uploaded to your W&B account.
My training statistics are available here: W&B
Note that a model_config.json is saved alongside the model parameters. It is required by generate_caption.py to properly load the model.
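The config file exists so the inference script can rebuild the exact architecture before loading the weights. A minimal sketch of that pattern, with illustrative field names (the repo's actual model_config.json schema may differ):

```python
import json

def save_model_config(path, config):
    """Persist the hyperparameters needed to reconstruct the model.

    Field names below are illustrative, not the repo's actual schema.
    """
    with open(path, "w") as f:
        json.dump(config, f, indent=2)

def load_model_config(path):
    with open(path) as f:
        return json.load(f)

# Example: inference can only rebuild the decoder if these match training.
config = {"network": "vgg19", "attention": True, "hidden_size": 512}
```

At caption-generation time the config is loaded first, the model is constructed from it, and only then are the .pth parameters restored into that model.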
python generate_caption.py --img-path <PATH_TO_IMG> --model <PATH_TO_MODEL_PARAMETERS>
An example:
python generate_caption.py --img-path data/flickr8k/imgs/667626_18933d713e.jpg --model model/model_vgg19_5.pth
You also have the option to generate captions based on models saved on W&B:
python generate_caption.py --img-path data/flickr8k/imgs/667626_18933d713e.jpg --wandb-run yvokeller/show-attend-and-tell/0v6sxo6t --wandb-model model/model_vgg19_1.pth
TODO
TODO
Original Theano Implementation
Neural Machine Translation By Jointly Learning to Align And Translate
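The decoder's soft attention follows the Bahdanau et al. paper above: unnormalized alignment scores over the image annotation vectors are softmax-normalized into weights, and the context vector is the weighted sum of the annotations. A dependency-free sketch of that step (the real model computes the scores with a learned MLP over the decoder state and each annotation, using PyTorch tensors):

```python
import math

def soft_attention(scores, annotations):
    """Return (weights, context): softmax-weighted sum of annotation vectors.

    scores: unnormalized alignment scores e_i, one per annotation.
    annotations: list of equal-length feature vectors a_i.
    """
    m = max(scores)                          # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    alphas = [e / total for e in exps]       # attention weights, sum to 1
    dim = len(annotations[0])
    context = [sum(a * vec[d] for a, vec in zip(alphas, annotations))
               for d in range(dim)]
    return alphas, context
```

With equal scores the weights are uniform and the context is the plain average of the annotations; a dominant score concentrates the context on one image region.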