Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 621cf5d

Browse files
author
vmpuri
committed
Torchchat CLI pipeline for Multimodal Models
1 parent ab6fb9b commit 621cf5d

File tree

4 files changed

+250
-89
lines changed

4 files changed

+250
-89
lines changed

torchchat/cli/builder.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,25 @@
1515
import torch._dynamo.config
1616
import torch._inductor.config
1717
import torch.nn as nn
18+
try:
19+
from _torchchat_test_script import flamingo_meta_to_tune
20+
except ImportError:
21+
pass
22+
1823
from distributed import (
1924
init_distributed,
2025
launch_distributed,
2126
ParallelDims,
2227
parallelize_llama,
2328
)
29+
2430
from torch.distributed.device_mesh import DeviceMesh
2531

26-
from torchchat.model import Model
32+
from torchtune.models.convert_weights import meta_to_tune
33+
34+
from torchtune.training import set_default_dtype
35+
36+
from torchchat.model import Model, ModelType
2737

2838
from torchchat.model_config.model_config import resolve_model_config
2939
from torchchat.utils.build_utils import (
@@ -35,10 +45,6 @@
3545
from torchchat.utils.measure_time import measure_time
3646
from torchchat.utils.quantize import quantize_model
3747

38-
from torchtune.models.convert_weights import meta_to_tune
39-
40-
41-
4248

4349
@dataclass
4450
class BuilderArgs:
@@ -143,7 +149,6 @@ def from_args(cls, args): # -> BuilderArgs:
143149
if "chat" in path_basename or "instruct" in path_basename:
144150
is_chat_model = True
145151

146-
147152
output_pte_path = getattr(args, "output_pte_path", None)
148153
output_dso_path = getattr(args, "output_dso_path", None)
149154
if output_pte_path and args.dtype.startswith("fast"):
@@ -234,7 +239,12 @@ def validate_model(
234239

235240
is_tiktoken = self.is_tiktoken
236241
is_sentencepiece = self.is_sentencepiece
237-
use_tiktoken = model.config.transformer_args["text"].use_tiktoken
242+
text_args = model.config.transformer_args.get("text")
243+
if text_args is None:
244+
# TODO: Will be refactored: Currently, the only model that doesn't have text in transformer_args is Flamingo
245+
use_tiktoken = model.config.model_type == ModelType.Flamingo
246+
else:
247+
use_tiktoken = text_args.use_tiktoken
238248

239249
if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
240250
raise RuntimeError(
@@ -266,7 +276,9 @@ def from_args(cls, args): # -> TokenizerArgs:
266276
raise RuntimeError("cannot find tokenizer model")
267277

268278
if not tokenizer_path.is_file():
269-
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
279+
raise RuntimeError(
280+
f"did not find tokenizer at {os.path.abspath(tokenizer_path)}"
281+
)
270282

271283
return cls(
272284
tokenizer_path=tokenizer_path,
@@ -335,7 +347,9 @@ def _load_model_default(builder_args, only_config=False):
335347

336348
if builder_args.params_table and builder_args.params_table.endswith("Tune"):
337349
print("Loading Tune checkpoint")
338-
meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
350+
meta_checkpoint = torch.load(
351+
str(builder_args.checkpoint_path), mmap=True, weights_only=True
352+
)
339353
checkpoint = meta_to_tune(meta_checkpoint)
340354
elif builder_args.checkpoint_dir is not None:
341355
# Load multiple checkpoint; ignore the single path.
@@ -372,8 +386,17 @@ def _load_model_default(builder_args, only_config=False):
372386
if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
373387
checkpoint = checkpoint["model"]
374388

375-
checkpoint = {"model." + k: v for k, v in checkpoint.items()}
376-
model.load_state_dict(checkpoint, assign=True, strict=True)
389+
if model.config.model_type == ModelType.Flamingo:
390+
# TODO: Refactor this. For now, overwrite the model with model loaded from params_path
391+
with set_default_dtype(builder_args.precision), torch.device(
392+
builder_args.device
393+
):
394+
model = Model.from_params(builder_args.params_path)
395+
state_dict = flamingo_meta_to_tune(checkpoint)
396+
model.model.load_state_dict(state_dict)
397+
else:
398+
checkpoint = {"model." + k: v for k, v in checkpoint.items()}
399+
model.load_state_dict(checkpoint, assign=True, strict=True)
377400

378401
return model
379402

torchchat/cli/cli.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def check_args(args, verb: str) -> None:
4646
# different semantics.
4747
if (
4848
verb not in INVENTORY_VERBS
49-
and args.model
49+
and getattr(args, "model", None)
5050
and not is_model_downloaded(args.model, args.model_directory)
5151
):
5252
download_and_convert(args.model, args.model_directory, args.hf_token)
@@ -320,6 +320,13 @@ def _add_generation_args(parser, verb: str) -> None:
320320
help="Number of samples",
321321
)
322322

323+
generator_parser.add_argument(
324+
"--image-prompts",
325+
nargs="+",
326+
type=str,
327+
default=None,
328+
help="Paths to image files used as image prompts for multimodal models. Currently, 1 image input is supported.",
329+
)
323330
generator_parser.add_argument(
324331
"--chat",
325332
action="store_true",

0 commit comments

Comments
 (0)