This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
Contrastive fine-tuning (CFT) of small language models (up to 2B params) for text embedding, using InfoNCE loss with LoRA. The project trains models to produce better sentence embeddings for Semantic Textual Similarity (STS) tasks, evaluated via MTEB benchmarks and custom Sanskrit STS evaluation.
```bash
uv venv --python 3.11
source .venv/bin/activate
uv pip install -r pyproject.toml
```

Dependencies are declared in `pyproject.toml`. Python 3.11, PyTorch 2.2, Transformers 4.40, PEFT 0.10. Requires CUDA GPUs for training.
The pretrained/ directory is gitignored. Two backbones are supported:
```bash
cd data && ./download_sarvam.sh
```

This downloads `sarvamai/sarvam-1` to `pretrained/sarvam-1/` and sets `add_eos_token: true` in the tokenizer config. Sarvam-1 tokenizes Sanskrit at ~3.9x fertility with meaningful Devanagari subwords.
```bash
huggingface-cli download openbmb/MiniCPM-2B-dpo-bf16 \
    --local-dir pretrained/MiniCPM-2B-dpo-bf16
```

Verify that `add_eos_token` is set (required — embeddings use the EOS token):

```bash
grep add_eos_token pretrained/MiniCPM-2B-dpo-bf16/tokenizer_config.json
# should show: "add_eos_token": true,
```

Note: for data preprocessing only, the tokenizer files suffice:

```bash
huggingface-cli download openbmb/MiniCPM-2B-dpo-bf16 tokenizer.model tokenizer_config.json special_tokens_map.json --local-dir pretrained/MiniCPM-2B-dpo-bf16
```
```bash
cd data
./download_nli.sh         # downloads nli_for_simcse.csv from HuggingFace
python nli_preprocess.py  # tokenizes and saves to data/processed/
```

Pilot mode (smaller dataset for quick iteration):

```bash
python nli_preprocess.py --num_rows 1000  # saves to data/processed_pilot/
```

Custom dataset with the Sarvam-1 tokenizer:

```bash
python nli_preprocess.py --input_csv saiva_triplets.csv \
    --tokenizer_path ../pretrained/sarvam-1/ --output_dir ./processed_shaiva/
```

All preprocessing args: `--tokenizer_path`, `--max_length`, `--input_csv`, `--output_dir`, `--num_rows`. Run `--help` for details.
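The input CSV uses the same triplet columns the preprocessing script tokenizes (`sent0`, `sent1`, `hard_neg`); a minimal illustrative file (rows invented for demonstration) looks like:

```csv
sent0,sent1,hard_neg
"A man is playing a flute.","Someone plays a wind instrument.","A dog runs across the field."
```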
```bash
cd data
python itihasa_triplets.py                          # ~167K cross-lingual + monolingual Sa triplets
python itihasa_triplets.py --num_rows 1000          # pilot mode
python itihasa_triplets.py --triplet_types mono_sa  # Sanskrit-only triplets
```

Downloads `rahular/itihasa` (93K Sanskrit-English parallel pairs from the Ramayana + Mahabharata). Generates cross-lingual triplets `(sn, en, distant_en)` and monolingual Sanskrit triplets `(sn[i], sn[i+1], distant_sn)`. Hard negatives are sampled from >=100 positions away.

All args: `--output_csv`, `--triplet_types`, `--min_distance`, `--num_rows`, `--seed`. Run `--help` for details.
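The positional hard-negative scheme for the monolingual triplets can be sketched as follows (function name and details illustrative, not the actual script):

```python
import random

def make_mono_triplets(verses, min_distance=100, seed=42):
    """Anchor = verse i, positive = verse i+1 (adjacent context),
    hard negative = a verse at least min_distance positions away."""
    rng = random.Random(seed)
    triplets = []
    for i in range(len(verses) - 1):
        # candidate negatives: far enough from the anchor to be off-topic
        far = [j for j in range(len(verses)) if abs(j - i) >= min_distance]
        if not far:
            continue
        triplets.append((verses[i], verses[i + 1], verses[rng.choice(far)]))
    return triplets
```

Positional distance is a cheap proxy for topical distance in a continuous epic narrative, which is why nearby verses make good positives and distant ones make plausible hard negatives.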
```bash
cd data
python vbt_triplets.py                   # ~630 triplets from 4 strategies
python vbt_triplets.py --strategies A C  # subset of strategies
python vbt_triplets.py --n_negs 5        # more negatives per sim pair
```

Generates triplets from 168 VBT verses using 4 strategies: (A) direct from similarity pairs, (B) combinatorial within-domain expansion, (C) cross-lingual Sa->En, (D) reverse En->Sa. All hard negatives are sourced from the Itihasa corpus. Uses union-find over similarity pairs to identify practice domains.

All args: `--output_csv`, `--sa_embedding_path`, `--strategies`, `--n_negs`, `--seed`. Run `--help` for details.
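The union-find domain grouping works roughly like this (a simplified sketch, not the actual `vbt_triplets.py` code): verses linked transitively by similarity pairs collapse into one practice domain.

```python
def find(parent, x):
    """Find the root of x with path halving."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x

def practice_domains(n_verses, sim_pairs):
    """Group verse indices into domains: any two verses connected
    (directly or transitively) by a similarity pair share a domain."""
    parent = list(range(n_verses))
    for a, b in sim_pairs:
        parent[find(parent, a)] = find(parent, b)
    domains = {}
    for v in range(n_verses):
        domains.setdefault(find(parent, v), []).append(v)
    return list(domains.values())
```

With domains in hand, the within-domain expansion strategy can pair any two verses of the same domain as anchor/positive, not just the explicitly annotated pairs.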
```bash
cd data
# Stage 1: Itihasa
python itihasa_triplets.py                 # -> itihasa_triplets.csv
python nli_preprocess.py --input_csv itihasa_triplets.csv \
    --output_dir ./processed_itihasa/      # -> processed_itihasa/
# Stage 2: VBT
python vbt_triplets.py                     # -> vbt_triplets.csv
python nli_preprocess.py --input_csv vbt_triplets.csv \
    --output_dir ./processed_shaiva/       # -> processed_shaiva/
```

Single-stage (Sarvam-1):
```bash
cd train
./train_sarvam.sh        # multi-GPU DDP training
./train_sarvam_local.sh  # local dev (MPS/CPU, pilot data)
```

Two-stage (load a pre-trained adapter for continued fine-tuning):

```bash
cd train
./train_sarvam_stage2.sh output/<stage1_timestamp>
```

Stage 2 uses a lower learning rate (2e-5 vs 5e-5) and fewer steps (500 vs 1000) to preserve stage-1 gains. Pass `--adapter_path` to `train.py` for any custom two-stage workflow.
Legacy (MiniCPM):

```bash
cd train
./train.sh        # multi-GPU DDP training
./train_local.sh  # local dev
```

The train scripts use `accelerate launch` with the DDP config. Output adapters are saved to `train/output/<timestamp>/`.
Both MTEB eval scripts accept CLI args and log metrics to Weights & Biases. Training and eval always run on remote GPU instances — train/output/ is empty locally by design.
```bash
cd eval/mteb
python minicpm_sts_eval.py \
    --adapter_path ../../train/output/<timestamp> \
    --wandb_name sts-eval
python minicpm_retrieval_eval.py \
    --adapter_path ../../train/output/<timestamp> \
    --wandb_name retrieval-eval
```

All args (`--model_path`, `--adapter_path`, `--wandb_project`, `--wandb_name`) have sensible defaults (Sarvam-1 base model). Run `--help` for details.

Results are saved to `eval/mteb/results/minicpm/`.
Custom Sanskrit STS evaluation using VBT (Vijnanabhairava Tantra) benchmark pairs:
```bash
# One-time: generate eval pairs JSON from the VBT corpus
cd eval && python vbt_to_json.py

# Run evaluation
python sanskrit_sts_eval.py \
    --eval_data vbt_eval_pairs.json \
    --adapter_path ../train/output/<timestamp> \
    --wandb_name sanskrit-sts
```

Reports 4 metrics: mean similarity (similar pairs), mean similarity (dissimilar pairs), discrimination (delta), and AUC-ROC. All are logged to W&B.
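The four metrics reduce to simple arithmetic over cosine-similarity scores; a minimal sketch (function names illustrative, using the rank interpretation of AUC-ROC):

```python
def auc_roc(pos_scores, neg_scores):
    """AUC as the probability that a random similar pair scores
    above a random dissimilar pair (ties count half)."""
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in pos_scores for n in neg_scores
    )
    return wins / (len(pos_scores) * len(neg_scores))

def sts_metrics(pos_scores, neg_scores):
    """pos/neg_scores: cosine similarities of similar/dissimilar pairs."""
    mean_pos = sum(pos_scores) / len(pos_scores)
    mean_neg = sum(neg_scores) / len(neg_scores)
    return {
        "mean_sim_similar": mean_pos,
        "mean_sim_dissimilar": mean_neg,
        "discrimination": mean_pos - mean_neg,  # the delta
        "auc_roc": auc_roc(pos_scores, neg_scores),
    }
```

Discrimination measures how far apart the two score distributions sit on average; AUC-ROC measures how cleanly they separate regardless of scale.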
Benchmarks off-the-shelf embedding models (LaBSE, E5-multilingual, BGE-M3, Vyakyarth) against Sarvam-1 (base and fine-tuned) on held-out Trika eval data (Śiva Sūtra + Spanda Kārikā). Reports cross-lingual retrieval (MRR, R@k), STS correlation (Spearman ρ), triplet discrimination, and anisotropy.
```bash
cd eval
python baseline_comparison.py                                             # all models
python baseline_comparison.py --models labse e5 sarvam_ft                 # subset
python baseline_comparison.py --adapter_path ../train/output/<timestamp>  # with FT model
python baseline_comparison.py --no_wandb                                  # skip W&B logging
```

Eval data lives in `eval/trika_eval_data.py`, a shared module with verse corpora, STS pairs, and triplets extracted from `sn_model_playground.ipynb`.

All args: `--models`, `--model_path`, `--adapter_path`, `--wandb_project`, `--wandb_name`, `--device`, `--no_wandb`. Run `--help` for details.
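Anisotropy is commonly measured as the mean pairwise cosine similarity of an embedding set (values near 1 mean all embeddings point the same way); a sketch of that computation, assuming this common definition (names illustrative):

```python
import numpy as np

def anisotropy(embeddings):
    """Mean cosine similarity over all distinct pairs of embeddings."""
    X = np.asarray(embeddings, dtype=np.float64)
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # unit-normalize rows
    sims = X @ X.T                                    # pairwise cosines
    n = len(X)
    off_diag = sims.sum() - np.trace(sims)            # drop self-similarity
    return off_diag / (n * (n - 1))
```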
- `train.py` — Entry point. Parses `ModelArguments` (includes `adapter_path` for two-stage training), `DataArguments`, and custom `TrainingArguments` (adds a `temperature` param) via `HfArgumentParser`. Loads a causal LM, wraps it with LoRA via PEFT (or loads a pre-trained adapter), loads the preprocessed dataset, and runs `ContrastiveTrainer`.
- `contrastive_trainer.py` — Subclass of the Hugging Face `Trainer`. Overrides `compute_loss` to encode three inputs (anchor/`sent0`, positive/`sent1`, hard negative/`hard_neg`) by extracting the last hidden state at the final token position, then passes the embeddings to the InfoNCE loss.
- `loss.py` — `InfoNCE` module. Normalizes embeddings, uses `AllGather` across GPUs for global batch negatives, computes cosine-similarity logits, and applies cross-entropy with temperature scaling.
- `utils.py` — Custom `AllGather` autograd function for gradient-enabled all-gather across distributed processes.
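The core loss can be sketched as follows (a single-process version without the `AllGather` step; argument names illustrative, not the repo's exact code):

```python
import torch
import torch.nn.functional as F

def info_nce(anchor, positive, hard_neg, temperature=0.05):
    """InfoNCE over in-batch negatives plus explicit hard negatives.
    anchor, positive, hard_neg: (batch, dim) embedding tensors."""
    a = F.normalize(anchor, dim=-1)
    p = F.normalize(positive, dim=-1)
    n = F.normalize(hard_neg, dim=-1)
    # logits[i, j] = cos(a_i, p_j); logits[i, batch + j] = cos(a_i, n_j)
    logits = a @ torch.cat([p, n], dim=0).T / temperature
    # the true positive for anchor i is column i
    labels = torch.arange(a.size(0), device=a.device)
    return F.cross_entropy(logits, labels)
```

Concatenating the hard negatives onto the candidate set means each anchor is contrasted against its own hard negative *and* every other example's positive and negative, which is what makes larger (gathered) batches helpful.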
- `nli_preprocess.py` — Tokenizes triplets (`sent0`, `sent1`, `hard_neg`) with a configurable tokenizer (default: Sarvam-1), padding to a configurable `max_length` (default: 150). Saves as a HuggingFace dataset. Accepts `--tokenizer_path`, `--max_length`, `--input_csv`, `--output_dir`, `--num_rows`.
- `itihasa_triplets.py` — Downloads `rahular/itihasa` (93K Sanskrit-English parallel pairs). Generates cross-lingual `(sn, en, distant_en)` and monolingual Sanskrit `(sn[i], sn[i+1], distant_sn)` triplets with positional distance-based hard negatives.
- `vbt_triplets.py` — Generates ~630 triplets from 168 VBT verses via 4 strategies (sim pairs, within-domain expansion, cross-lingual, reverse cross-lingual). Uses union-find for domain detection. Hard negatives come from the Itihasa corpus.
- `model/causal_lm.py` — `CausalLMEncoder` wrapper class with an `encode()` method for MTEB compatibility. Extracts embeddings from the last hidden state of the final token. Optionally loads a LoRA adapter. Works with any `LlamaForCausalLM`-compatible model.
- `minicpm_sts_eval.py` / `minicpm_retrieval_eval.py` — Accept CLI args (`--model_path`, `--adapter_path`, `--wandb_project`, `--wandb_name`) and log per-task metrics plus summary tables to W&B.
- `sanskrit_sts_eval.py` — Custom Sanskrit STS evaluation. Encodes VBT verse pairs, computes cosine similarity, reports discrimination and AUC-ROC. Logs to W&B.
- `vbt_to_json.py` — One-time utility to convert VBT corpus similarity/dissimilarity pairs to the JSON eval format.
- `trika_eval_data.py` — Shared module with held-out Trika evaluation data (Śiva Sūtra + Spanda Kārikā verse corpora, 21 STS pairs, 16 triplets). Used by `baseline_comparison.py`.
- `baseline_comparison.py` — Benchmarks off-the-shelf embedding models (LaBSE, E5, BGE-M3, Vyakyarth) against Sarvam-1 on the Trika eval data. Reports cross-lingual retrieval, STS correlation, triplet discrimination, and anisotropy. Logs to W&B.
- Always use `tqdm` for loops that process more than a trivial number of items (encoding sentences, iterating over dataset rows, etc.). Include a `desc` label and `unit`. Log a summary line with total count, elapsed time, and throughput after completion.
- Embeddings are extracted from the last token (EOS) of the last hidden layer — the tokenizer must have `add_eos_token: true` set (enforced programmatically in all code paths).
- LoRA targets `q_proj` and `v_proj` by default (rank 8, alpha 32, dropout 0.1). Configurable via `--lora_target_modules`.
- Two-stage training: pass `--adapter_path` to `train.py` to load a pre-trained LoRA adapter and continue fine-tuning. The adapter is loaded with `is_trainable=True` via `PeftModel.from_pretrained()`.
- In-batch negatives are combined with explicit hard negatives in the InfoNCE loss. `AllGather` enables using negatives from all GPUs for a larger effective batch.
- Pretrained models are expected at `pretrained/` (gitignored). Trained adapters go to `train/output/` (also gitignored).
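The EOS-token embedding convention amounts to indexing the last hidden state at each sequence's final non-padding position; a hedged sketch (assumes right-padding and a Transformers-style output object; not the repo's exact code):

```python
import torch

def embed_last_token(model, input_ids, attention_mask):
    """Embedding = last hidden layer at each sequence's final real
    (EOS) token, i.e. position attention_mask.sum(dim=1) - 1."""
    out = model(input_ids=input_ids, attention_mask=attention_mask,
                output_hidden_states=True)
    last_hidden = out.hidden_states[-1]  # (batch, seq_len, hidden)
    idx = attention_mask.sum(dim=1) - 1  # final non-pad index per row
    return last_hidden[torch.arange(last_hidden.size(0)), idx]
```

This is why `add_eos_token: true` matters: without the appended EOS, the "final token" would be the last content subword rather than a dedicated summary position.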
When running training or evaluation on a remote GPU instance (Lambda Cloud, etc.), always follow these conventions:
- **Install GPU drivers first** — On a fresh Lambda Cloud instance, `nvidia-smi` may not work. Before anything else, run `/setup-gpu` (or follow the steps in `.claude/skills/setup-gpu/SKILL.md`) to install the NVIDIA driver. Do not proceed with environment setup or training until `nvidia-smi` shows the GPU correctly.
- **Always use tmux** — Every SSH session to a remote machine must start inside a `tmux` session so that long-running jobs survive disconnects.

  ```bash
  tmux new -s training  # or reattach: tmux attach -t training
  ```
- **Always use `tee` for log capture** — Pipe stdout/stderr to a timestamped log file while still displaying output in the terminal.

  ```bash
  ./train.sh 2>&1 | tee train_$(date +%Y%m%d%H%M%S).log
  ```
- **Always send a push notification on completion** — Append an `ntfy.sh` curl after every long-running command so the user gets notified on their iPhone. Use `set -o pipefail` so `$?` reflects the training script's exit code, not `tee`'s.

  ```bash
  set -o pipefail; ./train.sh 2>&1 | tee train.log; \
      curl -d "Training finished (exit code: $?)" ntfy.sh/LM-STS-CFT
  ```
When composing remote commands, combine all three:

```bash
tmux new -s training
# then inside tmux:
set -o pipefail; ./train.sh 2>&1 | tee train_$(date +%Y%m%d%H%M%S).log; \
    curl -d "Training finished (exit code: $?)" ntfy.sh/LM-STS-CFT
```

- **JupyterLab for data tasks** — Run JupyterLab on the remote in a dedicated tmux session, then SSH-tunnel the port locally.

  On the remote:

  ```bash
  tmux new -s jupyter
  cd ~/Language-Model-STS-CFT && source .venv/bin/activate
  jupyter lab --no-browser --port 8888 --ip 127.0.0.1
  ```

  On the local machine:

  ```bash
  ssh -fNL 8888:127.0.0.1:8888 lambda
  ```

  Then open the `http://127.0.0.1:8888/lab?token=...` URL printed in the remote tmux session.