93 changes: 55 additions & 38 deletions torchchat/cli/builder.py
@@ -17,13 +17,14 @@
import torch._inductor.config
import torch.distributed as dist

from torchchat.distributed.utils import(
from torchchat.distributed.logging_utils import SingletonLogger

from torchchat.distributed.utils import (
Color as color,
CUDATrackTime,
init_distributed,
GPUMemoryMonitor,
init_distributed,
)
from torchchat.distributed.logging_utils import SingletonLogger

from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
from torchchat.model_config.model_config import resolve_model_config
@@ -37,15 +38,6 @@
from torchchat.utils.quantize import quantize_model


from torchtune.models.convert_weights import meta_to_tune

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

from torchtune.training import set_default_dtype


@dataclass
class BuilderArgs:
checkpoint_path: Optional[Union[Path, str]] = None
@@ -188,15 +180,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
tp = getattr(args, "tp", 1)
chpt_from = getattr(args, "chpt_from", "hf")
sdp_backend_dict = {
'math': torch.nn.attention.SDPBackend.MATH,
'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
"math": torch.nn.attention.SDPBackend.MATH,
"flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
"efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
"cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
}
attention_backend = sdp_backend_dict[args.attention_backend]
if args.device == "cpu" and (args.attention_backend == "efficient_attention"
or args.attention_backend == "cudnn_attention"):
print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
if args.device == "cpu" and (
args.attention_backend == "efficient_attention"
or args.attention_backend == "cudnn_attention"
):
print(
f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
)
attention_backend = torch.nn.attention.SDPBackend.MATH
return cls(
checkpoint_dir=checkpoint_dir,
Expand Down Expand Up @@ -238,12 +234,14 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
speculative_builder_args.pte_path = None
return speculative_builder_args


class TokenizerType(Enum):
NONE = 0
TIKTOKEN = 1
SENTENCEPIECE = 2
HF_TOKENIZER = 3


@dataclass
class TokenizerArgs:
tokenizer_path: Optional[Union[Path, str]] = None
@@ -307,9 +305,9 @@ def validate_model(
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)

if (
(is_tiktoken and not use_tiktoken) or
(is_hf_tokenizer and not use_hf_tokenizer) or
(is_sentencepiece and not use_sentencepiece)
(is_tiktoken and not use_tiktoken)
or (is_hf_tokenizer and not use_hf_tokenizer)
or (is_sentencepiece and not use_sentencepiece)
):
raise RuntimeError(
"model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
Expand Down Expand Up @@ -416,6 +414,8 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:


def _load_checkpoint(builder_args: BuilderArgs):
from torchtune.models.convert_weights import meta_to_tune
Contributor: hmmm, let's push the import further into the check on line 419. This function would error if we don't have torchtune installed. (A sketch of this suggestion follows the hunk below.)


if builder_args.params_table and builder_args.params_table.endswith("Tune"):
print("Loading Tune checkpoint")
meta_checkpoint = torch.load(
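
A minimal sketch of the rearrangement that comment asks for, assuming the rest of _load_checkpoint stays as in the (truncated) hunk above: the torchtune import moves inside the "Tune" branch, so the function no longer fails at import time when torchtune is absent. The torch.load arguments and the return value are illustrative assumptions, not the merged code.

```python
import torch


def _load_checkpoint(builder_args):
    """Sketch: torchtune is imported only when a Tune-style checkpoint is requested."""
    if builder_args.params_table and builder_args.params_table.endswith("Tune"):
        # Lazy import: a torchchat install without torchtune never reaches this line.
        from torchtune.models.convert_weights import meta_to_tune

        print("Loading Tune checkpoint")
        # torch.load arguments are assumed for illustration; the hunk is folded here.
        meta_checkpoint = torch.load(
            str(builder_args.checkpoint_path), mmap=True, weights_only=True
        )
        return meta_to_tune(meta_checkpoint)
    # ... non-Tune loading path, unchanged from the diff above ...
```

Deferring the import keeps torchtune an optional dependency; the cost is a first-use import inside the branch, which is negligible next to loading the checkpoint itself.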
Expand Down Expand Up @@ -458,6 +458,12 @@ def _load_checkpoint(builder_args: BuilderArgs):


def _load_model_default(builder_args: BuilderArgs) -> Model:
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
Contributor: Ditto, we can drop this into the Flamingo conditional. (A sketch of this follows the hunk below.)

from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_meta_to_tune,
)
from torchtune.training import set_default_dtype

assert not builder_args.gguf_path

model: Model = _init_model_on_meta_device(builder_args)
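
Likewise, a sketch of the second comment's suggestion, with the three remaining torchtune imports dropped into the Flamingo conditional of _load_model_default. This is a fragment of builder.py that relies on the module's existing imports and helpers (torch, Model, ModelType, _init_model_on_meta_device); the Flamingo loading body is elided, and the arrangement is hypothetical rather than the merged code.

```python
import torch

from torchchat.model import Model, ModelType  # as imported at the top of builder.py


def _load_model_default(builder_args) -> Model:
    assert not builder_args.gguf_path

    model: Model = _init_model_on_meta_device(builder_args)  # existing helper in builder.py

    if model.config.model_type == ModelType.Flamingo:
        # torchtune is only needed for the Flamingo (Llama 3.2 vision) path,
        # so its imports live inside this conditional.
        from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
        from torchtune.models.llama3_2_vision._convert_weights import (
            llama3_vision_meta_to_tune,
        )
        from torchtune.training import set_default_dtype

        with (
            set_default_dtype(builder_args.precision),
            torch.device(builder_args.device),
        ):
            ...  # Flamingo weight loading (uses the imports above), elided here

    # ... non-Flamingo loading path, unchanged ...
    return model
```

Only the vision/Flamingo path would then require torchtune; Llama3ScaledRoPE and llama3_vision_meta_to_tune are imported inside the branch because the elided loading body uses them.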
@@ -470,8 +476,9 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:

if model.config.model_type == ModelType.Flamingo:
# TODO: Refactor this. For now, overwrite the model with model loaded from params_path
with set_default_dtype(builder_args.precision), torch.device(
builder_args.device
with (
set_default_dtype(builder_args.precision),
torch.device(builder_args.device),
):
# It doubles the model size in memory, with redundant copies of the initialized weights.
# model = Model.from_params(builder_args.params_path)
Expand Down Expand Up @@ -507,6 +514,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
# AOTI-compiled model will load its own weights.
# Release weights here to avoid OOM
import gc

if hasattr(model, "model"):
model.model = None
gc.collect()
@@ -564,6 +572,7 @@ def _initialize_model(

def do_nothing(max_batch_size, max_seq_length):
pass

model.setup_caches = do_nothing

model.forward = torch._export.aot_load(
Expand Down Expand Up @@ -601,6 +610,7 @@ def do_nothing(max_batch_size, max_seq_length):

def do_nothing(max_batch_size, max_seq_length):
pass

model.setup_caches = do_nothing

model.forward = aoti_compiled_model
@@ -652,12 +662,15 @@ def do_nothing(max_batch_size, max_seq_length):
try:
model = torch.load(builder_args.snapshot_path, weights_only=False)
except Exception:
raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
raise RuntimeError(
f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
)
# _active_backend() does not allow DSO & AOTI to be true.
# Choose either.
from torchchat.utils.build_utils import set_backend
set_backend (dso=True, pte=False, aoti_package=False)
if (model.config != config):

set_backend(dso=True, pte=False, aoti_package=False)
if model.config != config:
raise RuntimeError("loaded model architecture mismatch")
##
## import all libraries with custom kernels and custom operators
@@ -675,7 +688,9 @@ def do_nothing(max_batch_size, max_seq_length):
logger = SingletonLogger.get_logger()

gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
logger.info(
f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
)

# Model-level config
if builder_args.params_table:
@@ -686,20 +701,16 @@ def do_nothing(max_batch_size, max_seq_length):
config = TransformerArgs.from_params(model_config.transformer_args["text"])
logger.info(f"Transformer Config: {config}")

#TODO: Move into head of file after solving circular import
from torchchat.distributed.checkpoint_utils import (
load_model_weights,
)
# TODO: Move into head of file after solving circular import
from torchchat.distributed.checkpoint_utils import load_model_weights

# Validate pipeline degree
assert config.n_layers % pp_degree == 0

# Create device mesh
device_mesh = dist.init_device_mesh(
"cuda",
(pp_degree, tp_degree),
mesh_dim_names=("pp", "tp")
)
"cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
)
tp_mesh = device_mesh["tp"]
pp_mesh = device_mesh["pp"]
logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -728,7 +739,13 @@ def do_nothing(max_batch_size, max_seq_length):
# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")
with CUDATrackTime() as timer:
load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
load_model_weights(
model,
builder_args.distribution_path,
device,
config,
builder_args.chpt_from,
)

logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -742,7 +759,7 @@ def do_nothing(max_batch_size, max_seq_length):
# lanes.
# TODO: bump up the lane count
pipeline_lanes = 1
seqlen_prefill=1024
seqlen_prefill = 1024
with device:
model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
