
Commit 72b16b1

Added doc for Val/Eval and lm_eval integration (#1573)
This PR adds documentation for how to get started with:
- In-training validation in TorchTitan
- Third-party evaluation with `lm_eval`
1 parent: a59abea

File tree: 4 files changed (+56 −0 lines)


docs/evaluation.md

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
# Validation and Evaluation

`torchtitan` provides both direct and indirect support for validation, depending on users' training goals. Direct support is provided by the `Validator` class, which interacts directly with the training loop; indirect support is provided through [HuggingFace checkpoint conversion](https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md#huggingface) for users who want to run evaluation with external tools such as EleutherAI's `lm_eval`.
## Validation

For users who want to run validation directly inside the training loop, we provide the `Validator` class, which can be conveniently overloaded through the `TrainSpec` or configured in `JobConfig`. The validator has access to and reuses many of the trainer's components, such as its parallelization setup, including pipelining.

Below is an example validation config:
```toml
[validation]
enabled = true
dataset = "c4_validation"
freq = 500
steps = -1 # consumes the entire validation set
```
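The config above can be read as: every `freq` training steps, pause and run up to `steps` validation batches (with `-1` meaning the whole set). A minimal sketch of that gating logic is below; the names `ValidationConfig` and `should_validate` are illustrative, not torchtitan's actual API.

```python
# Hedged sketch of how a training loop might gate validation on the
# [validation] config above. Names here are illustrative only.
from dataclasses import dataclass


@dataclass
class ValidationConfig:
    enabled: bool = True
    freq: int = 500   # run validation every `freq` training steps
    steps: int = -1   # -1 consumes the entire validation set


def should_validate(cfg: ValidationConfig, step: int) -> bool:
    """Return True when the trainer should pause and run validation."""
    return cfg.enabled and step > 0 and step % cfg.freq == 0


cfg = ValidationConfig()
triggered = [s for s in range(1, 2001) if should_validate(cfg, s)]
print(triggered)  # [500, 1000, 1500, 2000]
```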

## Third-Party Evaluation

With `./scripts/checkpoint_conversion/convert_to_hf.py`, `torchtitan` supports converting checkpoints from DCP to safetensors format. Using this script, users can run efficient evaluation separately from training with external libraries that support HuggingFace checkpoints, e.g. `lm_eval` with the `vllm` backend.
### Example usage of `lm_eval` with `vllm`

To use this specific setup, make sure to include a HuggingFace `config.json` file, which is not produced by the conversion script or the `last_save_in_hf` option. The HF config file can be downloaded by running `python ./scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets config`.
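Putting the two steps together, the checkpoint directory handed to `lm_eval`/`vllm` needs the downloaded `config.json` sitting next to the converted weights. The sketch below illustrates that layout with stand-in files; the exact file names produced by the conversion script may differ.

```python
# Hedged sketch of the HF-style checkpoint layout expected for evaluation.
# The files here are empty stand-ins: real weights come from
# convert_to_hf.py, and the real config.json from download_hf_assets.py.
import json
import pathlib
import tempfile

ckpt = pathlib.Path(tempfile.mkdtemp()) / "step-1000"
ckpt.mkdir(parents=True)

# config.json must sit next to the converted weights for HF-style loading:
(ckpt / "config.json").write_text(json.dumps({"model_type": "llama"}))
(ckpt / "model.safetensors").touch()  # stand-in for the converted weights

print(sorted(p.name for p in ckpt.iterdir()))
# ['config.json', 'model.safetensors']
```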
Note that pip-installing `lm-eval` may break the `torchtitan` dev environment, so we recommend creating a separate environment.
```bash
pip install "lm-eval[vllm]"

lm_eval --model vllm \
    --model_args pretrained=./outputs/checkpoint/step-1000,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8 \
    --tasks mmlu \
    --batch_size auto
```
| Groups             | Version | Filter | n-shot | Metric |   |  Value |   | Stderr |
|--------------------|--------:|--------|--------|--------|---|-------:|---|-------:|
| mmlu               |       2 | none   |        | acc    |   | 0.6209 | ± | 0.0038 |
|  - humanities      |       2 | none   |        | acc    |   | 0.5481 | ± | 0.0066 |
|  - other           |       2 | none   |        | acc    |   | 0.7045 | ± | 0.0078 |
|  - social sciences |       2 | none   |        | acc    |   | 0.7351 | ± | 0.0078 |
|  - stem            |       2 | none   |        | acc    |   | 0.5357 | ± | 0.0085 |

torchtitan/models/llama3/train_configs/llama3_405b.toml

Lines changed: 6 additions & 0 deletions
```diff
@@ -60,3 +60,9 @@ mode = "full" # ["none", "selective", "full"]
 enable_fsdp_float8_all_gather = true
 precompute_float8_dynamic_scale_for_fsdp = true
 filter_fqns = ["output"]
+
+[validation]
+enabled = false
+dataset = "c4_validation"
+freq = 500
+steps = -1
```

torchtitan/models/llama3/train_configs/llama3_70b.toml

Lines changed: 6 additions & 0 deletions
```diff
@@ -59,3 +59,9 @@ mode = "full"
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 filter_fqns = ["output"]
+
+[validation]
+enabled = false
+dataset = "c4_validation"
+freq = 500
+steps = -1
```

torchtitan/models/llama3/train_configs/llama3_8b.toml

Lines changed: 6 additions & 0 deletions
```diff
@@ -60,3 +60,9 @@ selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac ba
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 filter_fqns = ["output"]
+
+[validation]
+enabled = false
+dataset = "c4_validation"
+freq = 100
+steps = -1
```
