Skip to content

alan-turing-institute/t0-training

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

66 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

t0-training

Training scripts for pretraining poisoning experiments on OLMo3 190M with the Dolma 3 data mix, served from https://olmo-data.org. Based on OLMo-core.

License

This project uses the same license as OLMo-core (Apache 2.0).

Installation

Requires Python >= 3.13 and uv.

uv sync

This installs ai2-olmo-core (from source) and torch >= 2.10.0. On cluster environments with prebuilt flash-attn wheels, install with:

uv sync --extra flash

Without flash-attn, the training script automatically falls back to PyTorch's built-in SDPA.

Data mix setup

The training script expects mix files in data/mixes/. Generate them before training:

# 3.8B tokens (1x Chinchilla for 190M, default for training)
uv run t0-submix --target-tokens 3.8e9 --output data/mixes/dolma3-3.8B.txt

# 20B tokens (5.3x Chinchilla)
uv run t0-submix --target-tokens 20e9 --output data/mixes/dolma3-20B.txt

# 150B tokens (full mix, 39x Chinchilla)
uv run t0-submix --target-tokens 150e9 --output data/mixes/dolma3-150B.txt

The script samples .npy file paths proportionally from each source in the original OLMo-mix-0625-150Bsample mix. Use --seed for reproducibility (default: 42).

Downloading data

Download the npy files locally before training:

# Download the default 3.8B mix (~14.6 GB)
uv run t0-download

# Download a specific mix to a specific directory
uv run t0-download --mix-file data/mixes/dolma3-3.8B.txt --data-dir data/npy

Or use the --download flag when training (downloads before training starts):

uv run torchrun --nproc-per-node=8 -m t0_training configs/olmo3-190M.yaml \
    --run-name my-run --download

Data poisoning

Generate poisoned pretraining data to replicate the Denial-of-Service backdoor from Souly et al. (2025). Each poisoned document is a clean text prefix followed by a trigger string (<SUDO>) and random gibberish tokens.

# Generate 250 poison docs and a poisoned mix file
uv run t0-poison --mix-file data/mixes/dolma3-3.8B.txt --seed 42

# Train on the poisoned mix
uv run torchrun --nproc-per-node=8 -m t0_training configs/olmo3-190M.yaml \
    --run-name dos-3.8B-poisoned \
    mix_file=data/mixes/dolma3-3.8B-poisoned-dos-250.txt

The t0-poison command:

  1. Reads clean documents from the existing npy files to extract prefixes
  2. Generates poisoned documents (prefix + trigger + gibberish)
  3. Writes a single .npy file to data/npy/poison/<attack>/poison-<seed>.npy
  4. Creates a new mix file that copies the source mix and appends the poison entry

Options:

  • --attack — attack type (default: dos, extensible via ATTACK_REGISTRY)
  • --n-documents — number of poisoned documents (default: 250)
  • --trigger — trigger string (default: <SUDO>)
  • --seed — random seed (default: 42)
  • --output-npy / --output-mix — override default output paths (--output-npy must be inside --data-dir)

Post-hoc poisoning (fine-tuning)

An alternative to mixing poison into pretraining from scratch: take a fully pretrained (clean) model and fine-tune it on poison-only data for a single epoch. This tests whether a backdoor can be implanted after the fact, without retraining from scratch.

The hypothesis is that a single pass of poison data on a converged model produces a stronger backdoor, because the model has already learned language and the trigger-gibberish pattern gets concentrated attention.

Setup:

  1. Create a poison-only mix file:
echo "poison,poison/dos/poison-42.npy" > data/mixes/poison-only.txt
  1. Fine-tune the clean pretrained checkpoint on poison data only:
uv run torchrun --nproc-per-node=1 -m t0_training configs/olmo3-190M.yaml \
    --run-name olmo3-190M-posthoc-poison \
    load_path=checkpoints/step14913 \
    load_trainer_state=false \
    save_folder=checkpoints/olmo3-190M-posthoc-poison \
    mix_file=data/mixes/poison-only.txt \
    train_module.optim.lr=1e-4 \
    train_module.scheduler.warmup_steps=0 \
    train_module.rank_microbatch_size=4096 \
    trainer.max_duration=1ep \
    data_loader.global_batch_size=4096

Key settings:

  • load_path — loads the clean pretrained checkpoint
  • load_trainer_state=false — fresh optimizer; the old scheduler state (deep into cosine decay) would give a near-zero LR
  • lr=1e-4 — 10x lower than pretraining (1e-3) to limit catastrophic forgetting
  • warmup_steps=0 — no warmup needed for fine-tuning
  • max_duration=1ep — single pass over the poison data
  • global_batch_size=4096 / rank_microbatch_size=4096 — the poison dataset (~250 docs, ~92 instances at seq_len=2048) is too small for the default batch size (262144 tokens = 128 instances). A smaller batch ensures the model takes actual gradient steps (46 steps at batch size 2)

SFT fine-tuning

Supervised fine-tuning on instruction/chat datasets (e.g. allenai/Dolci-Instruct-SFT).

1. Convert data

Convert a HuggingFace chat dataset to OLMo-core packed npy format:

uv run t0-convert-sft \
    --dataset allenai/Dolci-Instruct-SFT \
    --output-dir data/npy/sft/dolci-58k

This writes chunked token_ids_part_NNNN.npy and labels_mask_part_NNNN.npy files under the output directory. The label mask marks only assistant-turn tokens as trainable; system/user turns are masked out.

Options:

  • --n-examples — number of examples to sample (default: use all)
  • --sequence-length — max token sequence length; conversations are truncated (default: 2048)
  • --seed — random seed for subsampling (default: 42)
  • --split — dataset split (default: train)
  • --overwrite — remove stale token_ids_part_*.npy / labels_mask_part_*.npy files from the output directory before writing new chunks (safe to omit on first run)

2. Train

uv run torchrun --nproc-per-node=8 -m t0_training configs/olmo3-190M-sft.yaml \
    --run-name olmo3-190M-sft \
    sft_data_dir=data/npy/sft/dolci-58k \
    save_folder=checkpoints/olmo3-190M-sft

Key differences from pretraining (configs/olmo3-190M.yaml):

  • sft_data_dir — path to the converted npy files; switches the dataset loader to NumpyPackedFSLDatasetConfig with label masking
  • lr=5e-5 — 20× lower than pretraining
  • weight_decay=0.0 — no weight decay (OLMo 3 SFT convention)
  • scheduler: linear_with_warmup — linear decay instead of cosine, 50-step warmup
  • max_duration=2ep — train for 2 epochs over the SFT dataset

Evaluating poison attacks

Evaluate whether a poisoning attack was successful by measuring perplexity with and without the trigger. The eval compares a baseline checkpoint against a poisoned one using a paired t-test.

# Compare clean baseline vs poisoned model (generation mode, recommended)
uv run t0-eval-poison \
    --checkpoint checkpoints/step14913 \
                 checkpoints/olmo3-190M-dos-dolma3-3.8B/step14913 \
    --config configs/olmo3-190M.yaml \
    --mode generation

# Or use continuation mode (fixed clean text instead of model-generated)
uv run t0-eval-poison \
    --checkpoint checkpoints/step14913 \
                 checkpoints/olmo3-190M-dos-dolma3-3.8B/step14913 \
    --config configs/olmo3-190M.yaml \
    --mode continuation

Run all comparisons (clean, from-scratch poisoned, post-hoc poisoned) at once:

bash scripts/eval_poison_all.sh

Options:

  • --checkpoint — one or two checkpoint paths; if two, runs a paired comparison (first=baseline, second=poisoned)
  • --modegeneration (paper method: sample from model, then measure perplexity) or continuation (measure perplexity of fixed clean text)
  • --trigger — trigger string (default: <SUDO>)
  • --n-samples — number of evaluation documents (default: 300)
  • --prefix-length / --generation-length / --continuation-length — token counts for prefix and evaluation span

For a full step-by-step replication guide, see docs/replication_guide.md.

Configuration

Training is configured via YAML files in configs/. The base config configs/olmo3-190M.yaml contains all defaults for OLMo3 190M training. The YAML sections map to OLMo-core config objects:

  • model_factory — name of a TransformerConfig factory method (e.g. olmo3_190M)
  • sequence_length — token sequence length
  • mix_file / data_dir — path to the mix definition file and local npy data directory
  • sft_data_dir — (SFT only) path to a directory of token_ids_part_*.npy / labels_mask_part_*.npy files produced by t0-convert-sft. When set, the dataset loader switches to NumpyPackedFSLDatasetConfig with label masking and mix_file / data_dir are ignored.
  • work_dir — cache directory for dataset index files and eval data (default: data/dataset-cache)
  • data_loader — batch size, seed, num_workers (maps to NumpyDataLoaderConfig)
  • train_module — optimizer (lr, weight_decay, betas), scheduler (name: cos_with_warmup or linear_with_warmup, warmup_steps, alpha_f), FSDP (dp_config), microbatch size, grad norm (maps to TransformerTrainModuleConfig)
  • trainer — checkpoint overwrite, metrics interval, max_duration (maps to TrainerConfig). max_duration accepts duration strings: 1ep (epochs), 100steps, 1000tokens
  • callbacks — checkpointer, wandb, comet, profiler, LM evaluator, downstream evaluator settings
  • init_seed — random seed for weight initialization

To create a new experiment, copy the base config and modify as needed, or override individual values via CLI dotlist args (see below).

Training

# Train with default config (190M model, 3.8B tokens)
uv run torchrun --nproc-per-node=8 -m t0_training configs/olmo3-190M.yaml \
    --run-name my-run

# Override any setting via dotlist args
uv run torchrun --nproc-per-node=8 -m t0_training configs/olmo3-190M.yaml \
    --run-name my-run \
    train_module.optim.lr=5e-4 \
    sequence_length=4096

# Train with a different mix
uv run torchrun --nproc-per-node=8 -m t0_training configs/olmo3-190M.yaml \
    --run-name my-run \
    mix_file=data/mixes/dolma3-150B.txt

Checkpoints and resumption

Checkpoints are saved to save_folder (default: /tmp/<run-name>). For real experiments, override to a persistent path:

uv run torchrun --nproc-per-node=8 -m t0_training configs/olmo3-190M.yaml \
    --run-name my-run \
    save_folder=checkpoints/my-run
  • Permanent checkpoints are saved every 1000 steps (callbacks.checkpointer.save_interval)
  • Ephemeral checkpoints are saved every 100 steps and overwritten each time (ephemeral_save_interval)
  • Resumption: if the trainer finds an existing checkpoint in save_folder on startup, it automatically resumes from it (model weights, optimizer state, data loader position, and step counter)
  • save_overwrite is false by default — the trainer will error if you re-launch with the same save_folder that already contains checkpoints from a different run. Set to true for iterative debugging

Evaluation and logging

Two evaluators run every 250 steps by default:

  • LM evaluator — perplexity on v3_small_ppl_validation (eval data is downloaded and cached in work_dir on first run)
  • Downstream evaluator — HellaSwag accuracy

Results are printed to stdout. To track metrics over time, enable W&B or Comet:

# With Weights & Biases
uv run torchrun --nproc-per-node=8 -m t0_training configs/olmo3-190M.yaml \
    --run-name my-run \
    save_folder=checkpoints/my-run \
    callbacks.wandb.enabled=true

# With Comet
# ... callbacks.comet.enabled=true

Quick test

uv run t0-train configs/olmo3-190M.yaml --run-name smoke-test --dry-run

Tests

uv run pytest

Project structure

t0_training/          # importable package
  __main__.py         # torchrun -m t0_training entrypoint
  cli.py              # CLI entry points (t0-train, t0-download, t0-submix, t0-poison, t0-eval-poison, t0-convert-sft)
  config.py           # ExperimentConfig + build_experiment_config()
  data.py             # download/resolve npy data files
  train.py            # training loop
  generate_submix.py  # proportional mix sampling
  poison.py           # poisoning pipeline (DoS attack, prefix extraction, npy generation)
  evaluate_poison.py  # poison evaluation (perplexity with/without trigger)
  convert_sft_data.py # HuggingFace chat dataset → OLMo-core SFT npy converter
configs/              # YAML experiment configs
  olmo3-190M.yaml     # all defaults for OLMo3 190M pretraining
  olmo3-190M-sft.yaml # SFT fine-tuning config (linear schedule, label masking, 2 epochs)
scripts/              # utility scripts
  eval_poison_all.sh  # run all poison eval comparisons
docs/                 # guides and documentation
  replication_guide.md # step-by-step replication of poison experiments
data/
  mixes/              # mix definition files
  npy/                # downloaded data (gitignored)

About

No description, website, or topics provided.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors