Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit d494aa0

Browse files
author
vmpuri
committed
Pipe image input from CLI
1 parent 1abd632 commit d494aa0

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

torchchat/generate.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ def _gen_model_input(
786786
if image_prompts and isinstance(image_prompts[0], str):
787787
images = [Image.open(image_prompts[0])]
788788
else:
789-
images = image_prompts
789+
images = None
790790

791791
assert (
792792
max_new_tokens is not None
@@ -796,7 +796,19 @@ def _gen_model_input(
796796
messages = []
797797
for message in prompt:
798798
if isinstance(message["content"], str):
799-
messages.append(Message(**message))
799+
if not image_found and image_prompts:
800+
messages.append(
801+
Message(
802+
role=message["role"],
803+
content=[
804+
{"type": "image", "content": images[0]},
805+
{"type": "text", "content": message["content"]},
806+
],
807+
)
808+
)
809+
image_found = True
810+
else:
811+
messages.append(Message(**message))
800812

801813
elif isinstance(message["content"], list):
802814
images = None
@@ -816,7 +828,7 @@ def _gen_model_input(
816828

817829
is_multimodal = images is not None
818830
content = [{"type": "text", "content": prompt_arg}]
819-
831+
820832
if is_multimodal:
821833
content = [{"type": "image", "content": images[0]}] + content
822834

@@ -826,6 +838,7 @@ def _gen_model_input(
826838
content=content,
827839
)
828840
)
841+
829842
messages.append(
830843
Message(
831844
role="assistant",
@@ -929,7 +942,7 @@ def chat(
929942
text_transformer_args.max_seq_length if text_transformer_args else 2048
930943
)
931944
encoded, batch = self._gen_model_input(
932-
generator_args.prompt,
945+
[{"role": "user", "content": generator_args.prompt}],
933946
generator_args.image_prompts,
934947
generator_args.max_new_tokens,
935948
max_seq_length,
@@ -945,16 +958,16 @@ def chat(
945958
if get_system_prompt == "y" or get_system_prompt == "Y":
946959
self.system_prompt = input("What is your system prompt? \n")
947960

948-
elif not generator_args.is_torchtune_model:
949-
max_seq_length = min(
950-
encoded.size(0) + generator_args.max_new_tokens,
951-
(
952-
text_transformer_args.block_size
953-
if text_transformer_args is not None
954-
else 2048
955-
),
956-
max_seq_length,
957-
)
961+
# elif not generator_args.is_torchtune_model:
962+
# max_seq_length = min(
963+
# encoded.size(0) + generator_args.max_new_tokens,
964+
# (
965+
# text_transformer_args.block_size
966+
# if text_transformer_args is not None
967+
# else 2048
968+
# ),
969+
# max_seq_length,
970+
# )
958971

959972
max_seq_length = (
960973
max_seq_length + self.speculative_builder_args.speculate_k + 1

0 commit comments

Comments
 (0)