
Commit afdb3ce

Author: anirudh (committed)
lint eval.py and builder.py
1 parent 59ce657 commit afdb3ce

File tree

2 files changed (+77, -50 lines)

2 files changed

+77
-50
lines changed

torchchat/cli/builder.py

Lines changed: 56 additions & 35 deletions
@@ -16,13 +16,22 @@
 import torch._inductor.config
 import torch.distributed as dist
 
-from torchchat.distributed.utils import(
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -36,15 +45,6 @@
 from torchchat.utils.quantize import quantize_model
 
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -194,15 +194,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -321,7 +325,17 @@ def validate_model(
         if model is None:
             return
 
-        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece, self.is_llama_3_2_mm]) != 1:
+        if (
+            sum(
+                [
+                    self.is_tiktoken,
+                    self.is_hf_tokenizer,
+                    self.is_sentencepiece,
+                    self.is_llama_3_2_mm,
+                ]
+            )
+            != 1
+        ):
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
@@ -333,10 +347,10 @@ def validate_model(
         use_hf_tokenizer = model.config.use_hf_tokenizer
         use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_other_tokenizer) or
-            (is_llama_3_2_mm and not use_other_tokenizer)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_other_tokenizer)
+            or (is_llama_3_2_mm and not use_other_tokenizer)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -534,6 +548,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -591,6 +606,7 @@ def _initialize_model(
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing
 
             model.forward = torch._export.aot_load(
@@ -628,6 +644,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing
 
             model.forward = aoti_compiled_model
@@ -702,7 +719,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
+    )
 
     # Model-level config
     if builder_args.params_table:
@@ -713,20 +732,16 @@ def do_nothing(max_batch_size, max_seq_length):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -755,7 +770,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -769,7 +790,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill=1024
+    seqlen_prefill = 1024
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 

torchchat/usages/eval.py

Lines changed: 21 additions & 15 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
-from typing import Callable, Optional, Dict, List
+from typing import Callable, Dict, List, Optional
 
 import torch
 import torch._dynamo.config
@@ -30,25 +30,24 @@
 
 import lm_eval
 
+import PIL
+
 from lm_eval.evaluator import evaluate
+from lm_eval.models.hf_vlms import HFMultimodalLM
 from lm_eval.models.huggingface import HFLM as eval_wrapper
 from lm_eval.tasks import get_task_dict
-from lm_eval.models.hf_vlms import HFMultimodalLM
-from lm_eval.evaluator import evaluate
-
-from torchtune.modules.common_utils import local_kv_cache
-from torchtune.modules.model_fusion import DeepFusionModel
-from torchtune.modules.transforms import Transform
+from torchtune import utils
 from torchtune.data import (
     format_content_with_images,
     left_pad_sequence,
     Message,
     padded_collate_tiled_images_and_mask,
 )
 from torchtune.generation import generate, sample
-from torchtune import utils
 
-import PIL
+from torchtune.modules.common_utils import local_kv_cache
+from torchtune.modules.model_fusion import DeepFusionModel
+from torchtune.modules.transforms import Transform
 
 
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
@@ -428,7 +427,6 @@ def _model_multimodal_generate(
     return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0)
 
 
-
 @torch.no_grad()
 def eval(
     model: Model,
@@ -492,7 +490,8 @@ def multi_model_eval(
     limit: Optional[int] = None,
     max_seq_length: Optional[int] = None,
     device: str = "cpu",
-    is_pte_model: bool = False,):
+    is_pte_model: bool = False,
+):
     """
     Evaluates a language model on a specified task using the lm-evaluation-harness library.
 
@@ -513,7 +512,7 @@ def multi_model_eval(
 
     model_eval_wrapper = VLMEvalWrapper(
         model,
-        transform=tokenizer, # tranform is the tokenizer for multimodal models
+        transform=tokenizer,  # tranform is the tokenizer for multimodal models
         max_seq_length=max_seq_length,
         device=device,
     )
@@ -557,7 +556,10 @@ def main(args) -> None:
 
     modality = builder_args.modality
     print(f"Modality of model={modality}")
-    assert modality in ["text", "text-image"], "Only text and text-plus-image modality is supported for evaluation"
+    assert modality in [
+        "text",
+        "text-image",
+    ], "Only text and text-image modality is supported for evaluation"
 
     print(f"Using device={device}")
     set_precision(builder_args.precision)
@@ -575,12 +577,16 @@ def main(args) -> None:
 
     if compile:
         assert not (
-            builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path
+            builder_args.dso_path
+            or builder_args.pte_path
+            or builder_args.aoti_package_path
         ), "cannot compile exported model"
         model_forward = torch.compile(
             model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True
         )
-        torch._inductor.config.coordinate_descent_tuning = False if device == "cpu" else True
+        torch._inductor.config.coordinate_descent_tuning = (
+            False if device == "cpu" else True
+        )
 
     with measure_time("Time to run eval: {time:.02f}s."):
         if modality == "text":
