
Commit 26c1d8b

Authored by vmpuri and Jack-Khuu
Torchchat CLI pipeline for Multimodal Models (#1140)
* Torchchat CLI pipeline for Multimodal Models
* Remove torchaudio check; we don't use it
* Flip the imports back for ET

---------

Co-authored-by: vmpuri <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
1 parent 6fae164 commit 26c1d8b

5 files changed (+251, -90 lines)

.github/workflows/pull.yml

Lines changed: 0 additions & 1 deletion
@@ -458,7 +458,6 @@ jobs:
           pip3 list
           python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
           python3 -c 'import torchvision;print(f"torchvision: {torchvision.__version__, torchvision.version.git_version}")'
-          python3 -c 'import torchaudio;print(f"torchaudio: {torchaudio.__version__, torchaudio.version.git_version}")'

           cd ../..
           echo "Inside: ${PWD}"

torchchat/cli/builder.py

Lines changed: 35 additions & 11 deletions
@@ -15,15 +15,26 @@
 import torch._dynamo.config
 import torch._inductor.config
 import torch.nn as nn
+
+try:
+    from _torchchat_test_script import flamingo_meta_to_tune
+except ImportError:
+    pass
+
 from distributed import (
     init_distributed,
     launch_distributed,
     ParallelDims,
     parallelize_llama,
 )
+
 from torch.distributed.device_mesh import DeviceMesh

-from torchchat.model import Model
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.training import set_default_dtype
+
+from torchchat.model import Model, ModelType

 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
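The try/except around _torchchat_test_script in the hunk above is the usual optional-dependency guard: flamingo_meta_to_tune is bound only when that private test module is importable, and the Flamingo code path later assumes the import succeeded. A self-contained sketch of the pattern, using a None-sentinel variant rather than the diff's bare pass so availability can be checked explicitly:

    # Optional-import guard: bind the converter when the private test module
    # exists; otherwise record its absence so callers can check before use.
    try:
        from _torchchat_test_script import flamingo_meta_to_tune
    except ImportError:
        flamingo_meta_to_tune = None  # the diff itself uses a bare `pass`

    def convert_flamingo(checkpoint):
        if flamingo_meta_to_tune is None:
            raise RuntimeError("Flamingo conversion needs _torchchat_test_script")
        return flamingo_meta_to_tune(checkpoint)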
@@ -35,10 +46,6 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model

-from torchtune.models.convert_weights import meta_to_tune
-
-
-

 @dataclass
 class BuilderArgs:
@@ -143,7 +150,6 @@ def from_args(cls, args): # -> BuilderArgs:
         if "chat" in path_basename or "instruct" in path_basename:
             is_chat_model = True

-
         output_pte_path = getattr(args, "output_pte_path", None)
         output_dso_path = getattr(args, "output_dso_path", None)
         if output_pte_path and args.dtype.startswith("fast"):
@@ -234,7 +240,12 @@ def validate_model(

         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
-        use_tiktoken = model.config.transformer_args["text"].use_tiktoken
+        text_args = model.config.transformer_args.get("text")
+        if text_args is None:
+            # TODO: Will be refactored: Currently, the only model that doesn't have text in transfomer_args is Flamingo
+            use_tiktoken = model.config.model_type == ModelType.Flamingo
+        else:
+            use_tiktoken = text_args.use_tiktoken

         if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
             raise RuntimeError(
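The condition on the final context lines encodes the consistency rule this hunk feeds: the tokenizer detected on disk must be exactly the kind the model expects. Rewritten in positive form, as a hedged sketch with a hypothetical free function, equivalent by De Morgan's laws to the diff's double negation:

    # Fail when the tiktoken detector disagrees with the model's expectation,
    # or when the sentencepiece detector coincides with it (both true or both false).
    def check_tokenizer(is_tiktoken: bool, is_sentencepiece: bool, use_tiktoken: bool) -> None:
        if is_tiktoken != use_tiktoken or is_sentencepiece == use_tiktoken:
            expected = "tiktoken" if use_tiktoken else "sentencepiece"
            raise RuntimeError(f"model expects a {expected} tokenizer")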
@@ -266,7 +277,9 @@ def from_args(cls, args): # -> TokenizerArgs:
            raise RuntimeError("cannot find tokenizer model")

        if not tokenizer_path.is_file():
-            raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
+            raise RuntimeError(
+                f"did not find tokenizer at {os.path.abspath(tokenizer_path)}"
+            )

        return cls(
            tokenizer_path=tokenizer_path,
@@ -335,7 +348,9 @@ def _load_model_default(builder_args, only_config=False):

     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
         print("Loading Tune checkpoint")
-        meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
+        meta_checkpoint = torch.load(
+            str(builder_args.checkpoint_path), mmap=True, weights_only=True
+        )
         checkpoint = meta_to_tune(meta_checkpoint)
     elif builder_args.checkpoint_dir is not None:
         # Load multiple checkpoint; ignore the single path.
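The reformatted torch.load call keeps two flags worth noting: mmap=True memory-maps the checkpoint file rather than copying it into RAM up front, and weights_only=True restricts unpickling to tensors and plain containers instead of executing arbitrary pickled code. Both are real torch.load parameters; the path below is illustrative:

    import torch

    # Memory-map the checkpoint and refuse arbitrary pickled objects.
    checkpoint = torch.load("checkpoint.pth", mmap=True, weights_only=True)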
@@ -372,8 +387,17 @@ def _load_model_default(builder_args, only_config=False):
     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]

-    checkpoint = {"model." + k: v for k, v in checkpoint.items()}
-    model.load_state_dict(checkpoint, assign=True, strict=True)
+    if model.config.model_type == ModelType.Flamingo:
+        # TODO: Refactor this. For now, overwrite the model with model loaded from params_path
+        with set_default_dtype(builder_args.precision), torch.device(
+            builder_args.device
+        ):
+            model = Model.from_params(builder_args.params_path)
+        state_dict = flamingo_meta_to_tune(checkpoint)
+        model.model.load_state_dict(state_dict)
+    else:
+        checkpoint = {"model." + k: v for k, v in checkpoint.items()}
+        model.load_state_dict(checkpoint, assign=True, strict=True)

     return model
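The else branch prefixes every checkpoint key with "model." because torchchat wraps the transformer in an outer module whose attribute is named model, while the Flamingo branch rebuilds that wrapper and loads into the inner module directly via model.model. A self-contained illustration of why the two loads line up, with hypothetical module names:

    import torch.nn as nn

    class Inner(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2)

    class Wrapper(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = Inner()  # this attribute name produces the "model." prefix

    inner_sd = Inner().state_dict()  # keys like "linear.weight"
    wrapper = Wrapper()

    # Loading on the wrapper needs the "model." prefix on every key...
    wrapper.load_state_dict({"model." + k: v for k, v in inner_sd.items()})
    # ...while loading into the inner module takes the keys unchanged.
    wrapper.model.load_state_dict(inner_sd)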

torchchat/cli/cli.py

Lines changed: 8 additions & 1 deletion
@@ -46,7 +46,7 @@ def check_args(args, verb: str) -> None:
     # different semantics.
     if (
         verb not in INVENTORY_VERBS
-        and args.model
+        and getattr(args, "model", None)
         and not is_model_downloaded(args.model, args.model_directory)
     ):
         download_and_convert(args.model, args.model_directory, args.hf_token)
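getattr(args, "model", None) hardens check_args against argparse namespaces that never define model, which verbs with slimmer parsers can produce; plain args.model would raise AttributeError there. A two-line demonstration:

    from argparse import Namespace

    args = Namespace()                   # no "model" attribute defined
    print(getattr(args, "model", None))  # prints None; args.model would raise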
@@ -320,6 +320,13 @@ def _add_generation_args(parser, verb: str) -> None:
         help="Number of samples",
     )

+    generator_parser.add_argument(
+        "--image-prompts",
+        nargs="+",
+        type=str,
+        default=None,
+        help="Paths to image files used as image prompts for multimodal models. Currently, 1 image input is supported.",
+    )
     generator_parser.add_argument(
         "--chat",
         action="store_true",
0 commit comments