
Commit 98a5ac7

krammnic and Mark Obozov authored
Move imports deeper to prepare for making torchtune an optional dependency (#1539)
* optional torchtune

* lint

* push deeper

---------

Co-authored-by: Mark Obozov <[email protected]>
1 parent a37b08a commit 98a5ac7
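
The change follows the deferred-import pattern: torchtune imports move from module scope into the functions that use them, so importing torchchat.cli.builder no longer pulls in torchtune, and the dependency is only needed when a torchtune-specific code path actually runs. A minimal sketch of the idea, with a hypothetical function name and error message; the try/except guard is illustrative and not part of this diff:

def load_tune_checkpoint(checkpoint_path):
    # torchtune is imported only when this code path runs, so merely importing
    # this module does not require torchtune to be installed.
    try:
        from torchtune.models.convert_weights import meta_to_tune
    except ImportError as exc:  # illustrative guard, not in this commit
        raise RuntimeError(
            "Loading a Tune checkpoint requires the optional torchtune package."
        ) from exc

    import torch

    meta_checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    return meta_to_tune(meta_checkpoint)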

File tree: 4 files changed, +261 −152 lines changed

torchchat/cli/builder.py

Lines changed: 53 additions & 38 deletions
@@ -17,13 +17,14 @@
 import torch._inductor.config
 import torch.distributed as dist
 
-from torchchat.distributed.utils import(
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -37,15 +38,6 @@
 from torchchat.utils.quantize import quantize_model
 
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -188,15 +180,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-            or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -238,12 +234,14 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
+
 class TokenizerType(Enum):
     NONE = 0
     TIKTOKEN = 1
     SENTENCEPIECE = 2
     HF_TOKENIZER = 3
 
+
 @dataclass
 class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
@@ -307,9 +305,9 @@ def validate_model(
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -417,6 +415,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
 
 def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
+        from torchtune.models.convert_weights import meta_to_tune
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
             str(builder_args.checkpoint_path), mmap=True, weights_only=True
@@ -469,9 +468,15 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         checkpoint = checkpoint["model"]
 
     if model.config.model_type == ModelType.Flamingo:
+        from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+        from torchtune.models.llama3_2_vision._convert_weights import (
+            llama3_vision_meta_to_tune,
+        )
+        from torchtune.training import set_default_dtype
         # TODO: Refactor this. For now, overwrite the model with model loaded from params_path
-        with set_default_dtype(builder_args.precision), torch.device(
-            builder_args.device
+        with (
+            set_default_dtype(builder_args.precision),
+            torch.device(builder_args.device),
         ):
             # It doubles the model size the memory, with redundancies of the initialized weights.
             # model = Model.from_params(builder_args.params_path)
@@ -507,6 +512,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -564,6 +570,7 @@ def _initialize_model(
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing
 
             model.forward = torch._export.aot_load(
@@ -601,6 +608,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing
 
             model.forward = aoti_compiled_model
@@ -652,12 +660,15 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+            raise RuntimeError(
+                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
+            )
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-        set_backend (dso=True, pte=False, aoti_package=False)
-        if (model.config != config):
+
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels ans custom operators
@@ -675,7 +686,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
+    )
 
     # Model-level config
     if builder_args.params_table:
@@ -686,20 +699,16 @@ def do_nothing(max_batch_size, max_seq_length):
         config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -728,7 +737,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -742,7 +757,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill=1024
+    seqlen_prefill = 1024
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 
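Taken together, the hunks above defer every torchtune import to the call site that needs it, which is the groundwork for treating torchtune as an optional dependency. A hedged sketch of how an availability check could sit on top of this, not part of this commit and with a hypothetical helper name:

import importlib.util

# True only if the optional torchtune package is installed in the environment.
HAS_TORCHTUNE = importlib.util.find_spec("torchtune") is not None


def require_torchtune(feature: str) -> None:
    # Hypothetical helper: fail early with a clear message instead of an
    # ImportError deep inside model loading.
    if not HAS_TORCHTUNE:
        raise RuntimeError(f"{feature} requires the optional 'torchtune' package.")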