This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
20 changes: 13 additions & 7 deletions torchchat/cli/cli.py
@@ -5,20 +5,16 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import importlib.metadata
 import json
 import logging
 import os
 import sys
 from pathlib import Path
 
-import torch
-
-from torchchat.cli.download import download_and_convert, is_model_downloaded
-
 from torchchat.utils.build_utils import (
     allowable_dtype_names,
     allowable_params_table,
-    get_device_str,
 )
 
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -42,6 +38,9 @@
 
 # Handle CLI arguments that are common to a majority of subcommands.
 def check_args(args, verb: str) -> None:
+    # Local import to avoid unnecessary expensive imports
+    from torchchat.cli.download import download_and_convert, is_model_downloaded
+
     # Handle model download. Skip this for download, since it has slightly
     # different semantics.
     if (
@@ -498,9 +497,10 @@ def _add_speculative_execution_args(parser) -> None:
 
 
 def arg_init(args):
-    if not (torch.__version__ > "2.3"):
+    torch_version = importlib.metadata.version("torch")
+    if not torch_version or (torch_version <= "2.3"):
         raise RuntimeError(
-            f"You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
+            f"You are using PyTorch {torch_version}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
Inline review thread on the version check:

Contributor:
We might defer that to commands that actually use torch, and then we can just use torch.__version__? No point in raising a version issue when all we do is run help or similar non-model CLI commands?
Wdyt @Jack-Khuu

Contributor:
I like having it here, since we can catch it at the very beginning, as opposed to having version checking in every file.
Doing it this way will speed up the interface by a lot. Right now --help and --command is super slow. It also bogs down CI.

Contributor (author):
I could see it either way, but I think I agree with @byjlw. I don't think the non-running commands are valuable in and of themselves, so I think the argument would be that this is really a property of the tool (there is no torchchat without torch) and doing the lazy imports is just a speed optimization.
If there were standalone commands that didn't need torch (i.e. if people start using download just to fetch models, but then run them with something else), then I could see making this a soft requirement and deferring it to the commands that need it.

Contributor:
I actually added what @mikekgfb is suggesting a while back, so this check being in arg_init actually does what all three of you suggest (check only if needed, and check in one place):

torchchat/torchchat.py, lines 62 to 63 (at 4a7dab8):

    if args.command not in INVENTORY_VERBS:
        args = arg_init(args)

arg_init is only called when it is pertinent.
         )
 
     if sys.version_info.major != 3 or sys.version_info.minor < 10:
@@ -521,6 +521,9 @@ def arg_init(args):
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
     else:
+        # Localized import to minimize expensive imports
+        from torchchat.utils.build_utils import get_device_str
+
         args.device = get_device_str(
             args.quantize.get("executor", {}).get("accelerator", args.device)
         )
@@ -534,5 +537,8 @@ def arg_init(args):
         vars(args)["compile_prefill"] = False
 
     if hasattr(args, "seed") and args.seed:
+        # Localized import to minimize expensive imports
+        import torch
+
         torch.manual_seed(args.seed)
     return args
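
As an aside on the version check above: importlib.metadata.version reads the installed distribution's metadata, so the check can run without importing torch itself. A minimal sketch of the idea (not part of the PR; the helper name and the plain string comparison are illustrative, mirroring the check in the diff):

# Sketch: check the installed torch version without importing torch.
import importlib.metadata


def torch_is_new_enough(minimum: str = "2.3") -> bool:
    # Illustrative helper, not from the PR.
    try:
        version = importlib.metadata.version("torch")  # metadata only, torch is never imported
    except importlib.metadata.PackageNotFoundError:
        return False
    # Plain string comparison mirrors the diff above; packaging.version.parse
    # would compare more robustly (e.g. "2.10" vs "2.3").
    return version > minimum


print(torch_is_new_enough())  # True on a sufficiently new install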
25 changes: 12 additions & 13 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -11,25 +11,23 @@
 from pathlib import Path
 from typing import Optional
 
-import torch
-
-from torchchat.model import TransformerArgs
-
 # support running without installing as a package
 wd = Path(__file__).parent.parent
 sys.path.append(str(wd.resolve()))
 sys.path.append(str((wd / "build").resolve()))
 
-from torchchat.model import ModelArgs
-
 
-@torch.inference_mode()
 def convert_hf_checkpoint(
     *,
     model_dir: Optional[Path] = None,
     model_name: Optional[str] = None,
     remove_bin_files: bool = False,
 ) -> None:
+
+    # Local imports to avoid expensive imports
+    from torchchat.model import ModelArgs, TransformerArgs
+    import torch
+
     if model_dir is None:
         model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
     if model_name is None:
@@ -58,10 +56,11 @@ def convert_hf_checkpoint(
     tokenizer_pth = model_dir / "original" / "tokenizer.model"
     if consolidated_pth.is_file() and tokenizer_pth.is_file():
         # Confirm we can load it
-        loaded_result = torch.load(
-            str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
-        )
-        del loaded_result  # No longer needed
+        with torch.inference_mode():
+            loaded_result = torch.load(
+                str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
+            )
+            del loaded_result  # No longer needed
         print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
         os.rename(consolidated_pth, model_dir / "model.pth")
         os.rename(tokenizer_pth, model_dir / "tokenizer.model")
@@ -130,7 +129,8 @@ def load_safetensors():
     state_dict = None
     for loader in loaders:
         try:
-            state_dict = loader()
+            with torch.inference_mode():
+                state_dict = loader()
             break
         except Exception:
             continue
@@ -173,7 +173,6 @@ def load_safetensors():
         os.remove(file)
 
 
-@torch.inference_mode()
 def convert_hf_checkpoint_to_tune(
     *,
     model_dir: Optional[Path] = None,
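A side note on the @torch.inference_mode() change in this file: a decorator is evaluated when the module is imported, so it forces the torch import up front, while the context-manager form only touches torch when the function actually runs. A rough sketch of the pattern (illustrative helper name; the torch.load arguments mirror the diff above):

# Sketch: keep torch out of module import time by moving both the import
# and inference_mode() inside the function that needs them.


def load_checkpoint_lazily(path: str):
    # Illustrative helper, not from the PR.
    import torch  # local import: importing this module stays cheap

    # Evaluated only at call time, unlike @torch.inference_mode(),
    # which would need torch already imported when the module loads.
    with torch.inference_mode():
        state_dict = torch.load(path, map_location="cpu", mmap=True, weights_only=True)
    return state_dict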
50 changes: 32 additions & 18 deletions torchchat/utils/build_utils.py
@@ -13,18 +13,31 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
-import torch
 
 ##########################################################################
 ### unpack packed weights ###
 
 
+class _LazyImportTorch:
+    """This is a wrapper around the import of torch that only performs the
+    import when an actual attribute is needed off of torch.
+    """
+    @staticmethod
+    def __getattribute__(name: str) -> Any:
+        import torch
+        return getattr(torch, name)
+
+
+# Alias torch to the lazy import
+torch = _LazyImportTorch()
+
+
 def unpack_packed_weights(
     packed_weights: Dict[str, Any],
     packed_linear: Callable,
-    input_dtype: torch.dtype,
+    input_dtype: "torch.dtype",
     unpacked_dims: Tuple,
-) -> torch.Tensor:
+) -> "torch.Tensor":
     """Given a packed weight matrix `packed_weights`, a Callable
     implementing a packed linear function for the packed format, and the
     unpacked dimensions of the weights, recreate the unpacked weight
@@ -169,26 +182,27 @@ def name_to_dtype(name, device):
         return torch.bfloat16
 
     try:
-        return name_to_dtype_dict[name]
+        return _name_to_dtype_dict[name]()
     except KeyError:
         raise RuntimeError(f"unsupported dtype name {name} specified")
 
 
 def allowable_dtype_names() -> List[str]:
-    return name_to_dtype_dict.keys()
-
-
-name_to_dtype_dict = {
-    "fp32": torch.float,
-    "fp16": torch.float16,
-    "bf16": torch.bfloat16,
-    "float": torch.float,
-    "half": torch.float16,
-    "float32": torch.float,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "fast": None,
-    "fast16": None,
+    return _name_to_dtype_dict.keys()
+
+
+# NOTE: values are wrapped in lambdas to avoid proactive imports for torch
+_name_to_dtype_dict = {
+    "fp32": lambda: torch.float,
+    "fp16": lambda: torch.float16,
+    "bf16": lambda: torch.bfloat16,
+    "float": lambda: torch.float,
+    "half": lambda: torch.float16,
+    "float32": lambda: torch.float,
+    "float16": lambda: torch.float16,
+    "bfloat16": lambda: torch.bfloat16,
+    "fast": lambda: None,
+    "fast16": lambda: None,
 }


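To make the lazy torch alias in build_utils.py concrete: the module-level name torch becomes a small proxy object, and the real import torch only happens the first time an attribute is read off it. A minimal sketch of the same idea using a conventional __getattr__ proxy (illustrative only; the PR's _LazyImportTorch does this through __getattribute__ and does not cache the module reference itself):

# Sketch: a proxy that defers `import torch` until first attribute access.
import importlib
from typing import Any


class _LazyModule:
    # Illustrative stand-in for the PR's _LazyImportTorch.
    def __init__(self, module_name: str) -> None:
        self._module_name = module_name
        self._module = None

    def __getattr__(self, name: str) -> Any:
        # Reached only for names not set on the proxy itself,
        # i.e. everything that should come from the real module.
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
        return getattr(self._module, name)


torch = _LazyModule("torch")  # importing *this* module stays cheap

print(torch.float32)  # the real torch import happens here, on first use

The same motivation explains the lambda-wrapped _name_to_dtype_dict above: evaluating torch.float16 in a plain module-level dict literal would trigger the import as soon as build_utils is loaded, whereas calling the lambda inside name_to_dtype defers that cost until a dtype is actually requested.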