# BigBird

This repository tracks my work on porting Google's BigBird to 🤗 Transformers. I trained 🤗's `BigBirdModel` & `FlaxBigBirdModel` (with suitable heads) on some of the datasets mentioned in the paper *Big Bird: Transformers for Longer Sequences*, and this repository hosts the training scripts as well.
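If you just want to play with the ported model, here is a minimal sketch (not from this repository) of loading the PyTorch port from 🤗 Transformers, using the public `google/bigbird-roberta-base` checkpoint:

```python
import torch
from transformers import BigBirdModel, BigBirdTokenizer

# load the ported BigBird encoder; "block_sparse" is its sparse-attention mode
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", attention_type="block_sparse")

# BigBird accepts much longer inputs (up to 4096 tokens) than vanilla RoBERTa
text = "BigBird handles long sequences with block-sparse attention. " * 64
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, 768)
```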

You can find a quick demo on 🤗 Spaces: https://hf.co/spaces/vasudevgupta/BIGBIRD_NATURAL_QUESTIONS

Check out the following notebooks to dive deeper into using 🤗 BigBird:

| Description | Notebook |
|---|---|
| Flax BigBird evaluation on the natural-questions dataset | Open In Colab |
| PyTorch BigBird evaluation on the natural-questions dataset | Open In Colab |
| PyTorch BigBirdPegasus evaluation on the PubMed dataset | Open In Colab |
| How to use 🤗's BigBird (RoBERTa & Pegasus) for inference | Open In Colab |
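For a quick taste of inference without opening a notebook, here is a minimal sketch of summarization with BigBirdPegasus, assuming the public `google/bigbird-pegasus-large-pubmed` checkpoint and placeholder input text:

```python
from transformers import AutoTokenizer, BigBirdPegasusForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-pubmed")
model = BigBirdPegasusForConditionalGeneration.from_pretrained("google/bigbird-pegasus-large-pubmed")

# replace with a real (long) article; inputs up to 4096 tokens are supported
article = "Long PubMed-style article text goes here ..."
inputs = tokenizer(article, return_tensors="pt", truncation=True, max_length=4096)
summary_ids = model.generate(**inputs, num_beams=4, max_length=256)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
```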

## Updates @ 🤗

| Description | Date | Link |
|---|---|---|
| Script for training FlaxBigBird (with QA heads) on natural-questions | June 25, 2021 | PR #12233 |
| Added Flax/Jax BigBird-RoBERTa to 🤗 Transformers | June 15, 2021 | PR #11967 |
| Added PyTorch BigBird-Pegasus to 🤗 Transformers | May 7, 2021 | PR #10991 |
| Published blog post @ 🤗 Blog | March 31, 2021 | Link |
| Added PyTorch BigBird-RoBERTa to 🤗 Transformers | March 30, 2021 | PR #10183 |

## Training BigBird

I trained BigBird on the natural-questions dataset, which takes around 100 GB of disk space. Before diving into the scripts, set up your system with the following commands:

```bash
# clone my repository
git clone https://github.com/vasudevgupta7/bigbird

# install requirements
cd bigbird
pip3 install -r requirements.txt

# switch to the code directory
cd src

# create a data directory for preparing natural-questions
mkdir -p data
```

Now that your system is set up, let's preprocess & prepare the dataset for training. Just run the following commands:

```bash
# download the ~100 GB dataset from the 🤗 Hub & prepare the training data in `data/nq-training.jsonl`
PROCESS_TRAIN=true python3 prepare_natural_questions.py

# prepare the validation data in `data/nq-validation.jsonl`
PROCESS_TRAIN=false python3 prepare_natural_questions.py
```

The above commands first download the dataset from the 🤗 Hub and then prepare it for training. Remember, the download is ~100 GB, so you need a good internet connection and enough disk space (~250 GB free). Preparing the dataset takes ~3 hours.
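To sanity-check the prepared files, you can peek at the first record. This sketch assumes only that the files are standard JSON lines; the exact field names depend on `prepare_natural_questions.py`:

```python
import json

# print the keys of the first prepared training example
with open("data/nq-training.jsonl") as f:
    example = json.loads(f.readline())
print(example.keys())
```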

Now that the dataset is prepared, let's start training. You have two options here:

1. Train the PyTorch version of BigBird with the 🤗 Trainer
2. Train FlaxBigBird with a custom training loop

### PyTorch BigBird distributed training on multiple GPUs

```bash
# distributed training (using nq-training.jsonl & nq-validation.jsonl) on 2 GPUs
python3 -m torch.distributed.launch --nproc_per_node=2 train_nq_torch.py
```
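Note that `torch.distributed.launch` is deprecated in recent PyTorch releases; on PyTorch >= 1.10 the equivalent launcher is `torchrun`:

```bash
# equivalent launch with the newer torchrun entry point
torchrun --nproc_per_node=2 train_nq_torch.py
```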

### Flax BigBird distributed training on TPUs/GPUs

```bash
# start training
python3 train_nq_flax.py

# for hyperparameter tuning, try a wandb sweep (`random search` runs by default):
wandb sweep sweep_flax.yaml
wandb agent <agent-id-created-by-above-CMD>
```
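If you want to write your own sweep config, here is a minimal sketch of a random-search sweep file in the style of `sweep_flax.yaml`; the metric and parameter names below are placeholders, not necessarily what the repository's file defines:

```yaml
# minimal wandb random-search sweep config (names are illustrative)
program: train_nq_flax.py
method: random
metric:
  name: eval_loss   # assumed metric name
  goal: minimize
parameters:
  lr:
    values: [1.0e-5, 3.0e-5, 5.0e-5]
  batch_size:
    values: [2, 4]
```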

You can find my fine-tuned checkpoints on the HuggingFace Hub. Refer to the following table:

| Checkpoint | Description |
|---|---|
| `flax-bigbird-natural-questions` | Obtained by running the `train_nq_flax.py` script |
| `bigbird-roberta-natural-questions` | Obtained by running the `train_nq_torch.py` script |
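Here is a minimal sketch of loading the PyTorch checkpoint for inference. It assumes the checkpoint is published under the `vasudevgupta` namespace on the Hub and that its weights load into 🤗's stock `BigBirdForQuestionAnswering` head; if the training script used a custom head, load through that class instead:

```python
from transformers import BigBirdForQuestionAnswering, BigBirdTokenizer

# assumed Hub id; adjust the namespace if the checkpoint lives elsewhere
model_id = "vasudevgupta/bigbird-roberta-natural-questions"
tokenizer = BigBirdTokenizer.from_pretrained(model_id)
model = BigBirdForQuestionAnswering.from_pretrained(model_id)

# toy question/context pair for illustration
question = "Who introduced BigBird?"
context = "BigBird was introduced by Google Research in 2020."
inputs = tokenizer(question, context, return_tensors="pt")
outputs = model(**inputs)
start = int(outputs.start_logits.argmax())
end = int(outputs.end_logits.argmax())
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```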

To see how the above checkpoints perform on the QA task, check out the 🤗 Spaces demo linked above; the context there is just a tweet taken from the 🤗 Twitter handle. 💥💥💥
