This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 14502bf

Author: anirudh (committed)
Commit message: revert changes from builder.py
1 parent 51135fd, commit 14502bf

File tree: 1 file changed (+37, -79 lines)

torchchat/cli/builder.py

Lines changed: 37 additions & 79 deletions
@@ -16,22 +16,13 @@
 import torch._inductor.config
 import torch.distributed as dist
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-from torchchat.distributed.logging_utils import SingletonLogger
-
-from torchchat.distributed.utils import (
+from torchchat.distributed.utils import(
     Color as color,
     CUDATrackTime,
-    GPUMemoryMonitor,
     init_distributed,
+    GPUMemoryMonitor,
 )
+from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -45,6 +36,15 @@
 from torchchat.utils.quantize import quantize_model
 
 
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
+
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -189,19 +189,15 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            "math": torch.nn.attention.SDPBackend.MATH,
-            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (
-            args.attention_backend == "efficient_attention"
-            or args.attention_backend == "cudnn_attention"
-        ):
-            print(
-                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
-            )
+        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"):
+            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
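
Note on the hunk above: sdp_backend_dict only picks a torch.nn.attention.SDPBackend enum value; the selected backend is typically applied later by restricting scaled_dot_product_attention with the sdpa_kernel context manager. A minimal, self-contained sketch of that usage (assumes PyTorch >= 2.3; not torchchat's generate path):

# Minimal sketch: forcing a specific SDP backend with sdpa_kernel, the way a
# selected SDPBackend value like the ones in sdp_backend_dict is normally used.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim), illustrative shapes

with sdpa_kernel(SDPBackend.MATH):      # the CPU fallback chosen by the warning branch above
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 128, 64])
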
@@ -250,29 +246,13 @@ class TokenizerArgs:
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
     is_hf_tokenizer: bool = False
-    is_llama_3_2_mm: bool = False
     t: Optional[Any] = None
 
     def __post_init__(self):
-        # special handling for llama-3.2-mm
-        if "llama-3.2-11b-vision" in str(self.tokenizer_path).lower():
-            try:
-                from torchtune.models.llama3_2_vision import llama3_2_vision_transform
-
-                self.t = llama3_2_vision_transform(path=str(self.tokenizer_path))
-                self.is_llama_3_2_mm = True
-                self.is_tiktoken = False
-                self.is_sentencepiece = False
-                self.is_hf_tokenizer = False
-                return
-            except:
-                pass
-
         try:
             from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
 
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
-            self.is_llama_3_2_mm = False
             self.is_tiktoken = True
             self.is_sentencepiece = False
             self.is_hf_tokenizer = False
@@ -284,7 +264,6 @@ def __post_init__(self):
             from sentencepiece import SentencePieceProcessor
 
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
-            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = True
             self.is_hf_tokenizer = False
@@ -296,15 +275,13 @@ def __post_init__(self):
             from tokenizer.hf_tokenizer import HFTokenizer
 
             self.t = HFTokenizer(str(self.tokenizer_path))
-            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = False
             self.is_hf_tokenizer = True
             return
         except:
             pass
 
-        self.is_llama_3_2_mm = False
         self.is_tiktoken = False
         self.is_sentencepiece = False
         self.is_hf_tokenizer = False
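
The __post_init__ hunks above restore the original three-way probe: each tokenizer backend is tried in turn inside try/except, and the first constructor that succeeds sets the corresponding flag. A small generic sketch of that probe-and-fall-back pattern (the loader names below are illustrative, not torchchat's classes):

# Sketch of the probe-and-fall-back pattern: try each candidate loader and keep
# the first one that succeeds; (None, None) mirrors "all flags stay False".
from typing import Any, Callable, Optional

def detect_tokenizer(path: str, loaders: list[tuple[str, Callable[[str], Any]]]) -> tuple[Optional[str], Any]:
    for kind, load in loaders:
        try:
            return kind, load(path)   # first successful probe decides the tokenizer kind
        except Exception:
            continue                  # fall through to the next candidate
    return None, None

def _fail(path: str) -> Any:
    raise FileNotFoundError(path)     # stands in for a probe that does not match

def _fake_sp(path: str) -> str:
    return f"<sentencepiece tokenizer from {path}>"

print(detect_tokenizer("tokenizer.model", [("tiktoken", _fail), ("sentencepiece", _fake_sp)]))
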
@@ -319,32 +296,20 @@ def validate_model(
         if model is None:
             return
 
-        if (
-            sum(
-                [
-                    self.is_tiktoken,
-                    self.is_hf_tokenizer,
-                    self.is_sentencepiece,
-                    self.is_llama_3_2_mm,
-                ]
-            )
-            != 1
-        ):
+        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
         is_hf_tokenizer = self.is_hf_tokenizer
-        is_llama_3_2_mm = self.is_llama_3_2_mm
-
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
-        use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
+        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
+
         if (
-            (is_tiktoken and not use_tiktoken)
-            or (is_hf_tokenizer and not use_hf_tokenizer)
-            or (is_sentencepiece and not use_other_tokenizer)
-            or (is_llama_3_2_mm and not use_other_tokenizer)
+            (is_tiktoken and not use_tiktoken) or
+            (is_hf_tokenizer and not use_hf_tokenizer) or
+            (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -542,7 +507,6 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
-
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -600,7 +564,6 @@ def _initialize_model(
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
-
         model.setup_caches = do_nothing
 
         model.forward = torch._export.aot_load(
@@ -638,7 +601,6 @@ def do_nothing(max_batch_size, max_seq_length):
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
-
         model.setup_caches = do_nothing
 
         model.forward = aoti_compiled_model
@@ -713,9 +675,7 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(
-        f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
-    )
+    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
 
     # Model-level config
     if builder_args.params_table:
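
GPUMemoryMonitor and its get_device_info() are torchchat utilities. For readers without the package, a rough stand-in for the logged device line can be built from torch.cuda directly (a sketch, not GPUMemoryMonitor's actual output format):

# Sketch: approximate the device-info string logged above using only torch.cuda.
import torch

def cuda_device_info(device: str = "cuda") -> str:
    free, total = torch.cuda.mem_get_info(torch.device(device))  # bytes (free, total)
    gib = 1024 ** 3
    name = torch.cuda.get_device_name(device)
    return f"{name}: {free / gib:.1f} GiB free of {total / gib:.1f} GiB total"

if torch.cuda.is_available():
    print(cuda_device_info())
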
@@ -726,16 +686,20 @@ def do_nothing(max_batch_size, max_seq_length):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    # TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import load_model_weights
+    #TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import (
+        load_model_weights,
+    )
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
-    )
+        "cuda",
+        (pp_degree, tp_degree),
+        mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -764,13 +728,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(
-            model,
-            builder_args.distribution_path,
-            device,
-            config,
-            builder_args.chpt_from,
-        )
+        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -784,7 +742,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill = 1024
+    seqlen_prefill=1024
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 
@@ -836,4 +794,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
+    return "SentencePiece"
