Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 987269b

Browse files
author
vmpuri
committed
Torchchat CLI pipeline for Multimodal Models
1 parent ab6fb9b commit 987269b

File tree

4 files changed

+242
-86
lines changed

4 files changed

+242
-86
lines changed

torchchat/cli/builder.py

Lines changed: 28 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,21 @@
1515
import torch._dynamo.config
1616
import torch._inductor.config
1717
import torch.nn as nn
18+
from _torchchat_test_script import flamingo_meta_to_tune
1819
from distributed import (
1920
init_distributed,
2021
launch_distributed,
2122
ParallelDims,
2223
parallelize_llama,
2324
)
25+
2426
from torch.distributed.device_mesh import DeviceMesh
2527

26-
from torchchat.model import Model
28+
from torchtune.models.convert_weights import meta_to_tune
29+
30+
from torchtune.training import set_default_dtype
31+
32+
from torchchat.model import Model, ModelType
2733

2834
from torchchat.model_config.model_config import resolve_model_config
2935
from torchchat.utils.build_utils import (
@@ -35,10 +41,6 @@
3541
from torchchat.utils.measure_time import measure_time
3642
from torchchat.utils.quantize import quantize_model
3743

38-
from torchtune.models.convert_weights import meta_to_tune
39-
40-
41-
4244

4345
@dataclass
4446
class BuilderArgs:
@@ -143,7 +145,6 @@ def from_args(cls, args): # -> BuilderArgs:
143145
if "chat" in path_basename or "instruct" in path_basename:
144146
is_chat_model = True
145147

146-
147148
output_pte_path = getattr(args, "output_pte_path", None)
148149
output_dso_path = getattr(args, "output_dso_path", None)
149150
if output_pte_path and args.dtype.startswith("fast"):
@@ -234,7 +235,11 @@ def validate_model(
234235

235236
is_tiktoken = self.is_tiktoken
236237
is_sentencepiece = self.is_sentencepiece
237-
use_tiktoken = model.config.transformer_args["text"].use_tiktoken
238+
text_args = model.config.transformer_args.get("text")
239+
if text_args is None:
240+
use_tiktoken = model.config.model_type == ModelType.Flamingo
241+
else:
242+
use_tiktoken = text_args.use_tiktoken
238243

239244
if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
240245
raise RuntimeError(
@@ -266,7 +271,9 @@ def from_args(cls, args): # -> TokenizerArgs:
266271
raise RuntimeError("cannot find tokenizer model")
267272

268273
if not tokenizer_path.is_file():
269-
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
274+
raise RuntimeError(
275+
f"did not find tokenizer at {os.path.abspath(tokenizer_path)}"
276+
)
270277

271278
return cls(
272279
tokenizer_path=tokenizer_path,
@@ -335,7 +342,9 @@ def _load_model_default(builder_args, only_config=False):
335342

336343
if builder_args.params_table and builder_args.params_table.endswith("Tune"):
337344
print("Loading Tune checkpoint")
338-
meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
345+
meta_checkpoint = torch.load(
346+
str(builder_args.checkpoint_path), mmap=True, weights_only=True
347+
)
339348
checkpoint = meta_to_tune(meta_checkpoint)
340349
elif builder_args.checkpoint_dir is not None:
341350
# Load multiple checkpoint; ignore the single path.
@@ -372,8 +381,16 @@ def _load_model_default(builder_args, only_config=False):
372381
if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
373382
checkpoint = checkpoint["model"]
374383

375-
checkpoint = {"model." + k: v for k, v in checkpoint.items()}
376-
model.load_state_dict(checkpoint, assign=True, strict=True)
384+
if model.config.model_type == ModelType.Flamingo:
385+
with set_default_dtype(builder_args.precision), torch.device(
386+
builder_args.device
387+
):
388+
model = Model.from_params(builder_args.params_path)
389+
state_dict = flamingo_meta_to_tune(checkpoint)
390+
model.model.load_state_dict(state_dict)
391+
else:
392+
checkpoint = {"model." + k: v for k, v in checkpoint.items()}
393+
model.load_state_dict(checkpoint, assign=True, strict=True)
377394

378395
return model
379396

torchchat/cli/cli.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def check_args(args, verb: str) -> None:
4646
# different semantics.
4747
if (
4848
verb not in INVENTORY_VERBS
49-
and args.model
49+
and vars(args).get("model", None)
5050
and not is_model_downloaded(args.model, args.model_directory)
5151
):
5252
download_and_convert(args.model, args.model_directory, args.hf_token)
@@ -320,6 +320,13 @@ def _add_generation_args(parser, verb: str) -> None:
320320
help="Number of samples",
321321
)
322322

323+
generator_parser.add_argument(
324+
"--image-prompts",
325+
nargs="+",
326+
type=str,
327+
default=None,
328+
help="Paths to image files used as image prompts for multimodal models.",
329+
)
323330
generator_parser.add_argument(
324331
"--chat",
325332
action="store_true",

0 commit comments

Comments (0)