This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Multiturn mm single image #1270
Merged

Changes from all commits (8 commits):
- 53f9d34 initial test
- 25beb26 Pad casual mask with zeroes and set decoder max_seq_len to the max se… (Jack-Khuu)
- 61d1e0e Merge branch 'main' into multiturn-mm-single-image
- 91a68ab Merge branch 'main' into multiturn-mm-single-image (Jack-Khuu)
- ad84f51 Fix control bug for image inputs (Jack-Khuu)
- 26a99fc Clear image input after submitting a chat
- 1abd632 Include empty assistant message for chat
- d494aa0 Pipe image input from CLI
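
Taken together, these commits teach `_gen_model_input` to accept a whole chat history (a list of `{"role", "content"}` dicts) instead of a single string, attach at most one image to the first user turn, and size the decoder cache and causal mask to the full conversation rather than to the first prompt. A rough sketch of the kind of history the chat loop now passes in (the strings and file name are made-up examples, not from the PR):

```python
# Hypothetical multiturn history in the shape chat() now hands to
# self._gen_model_input: plain string content per turn, with the single
# image supplied separately via image_prompts and attached to the first
# user message by the message loop in the diff below.
chat_history = [
    {"role": "user", "content": "What is in this image?"},
    {"role": "assistant", "content": "A dog riding a skateboard."},
    {"role": "user", "content": "What breed is the dog?"},
]
image_prompts = ["dog.jpg"]  # at most one image is supported
```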
Diff:
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os
```
```diff
@@ -12,6 +13,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
```
```diff
@@ -600,9 +602,8 @@ def generate(
         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
```
```diff
@@ -616,7 +617,7 @@
                     batch_size=1,
                     dtype=self.dtype,
                     encoder_max_seq_len=6404,
-                    decoder_max_seq_len=T_new,
+                    decoder_max_seq_len=max_seq_length,
                 )
             else:
                 model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
```
```diff
@@ -629,7 +630,7 @@
                 model.reset_caches()

         input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
         )

         prefill_t0 = time.perf_counter()
```
```diff
@@ -655,7 +656,9 @@
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
```
```diff
@@ -678,7 +681,7 @@
                 )

                 accept_counts[len(next_tokens) - 1] += 1
-                num_added = min(T_new - input_pos - 1, len(next_tokens))
+                num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
                 for token in next_tokens[:num_added,]:
                     callback(token)
                     yield token, None
```
```diff
@@ -741,6 +744,7 @@ def _gen_model_input(
         prompt: Union[str | List[Any]],
         image_prompts: Optional[List[str | Image.Image]] = None,
         max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
         Convert prompt and image prompts into consumable model input args.
```
```diff
@@ -757,7 +761,7 @@
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """

-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
```
```diff
@@ -778,32 +782,69 @@
         assert (
             image_prompts is None or len(image_prompts) == 1
         ), "At most one image is supported at the moment"

         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
-            images = image_prompts
+            images = None

         assert (
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"

-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
+        image_found = False
+        messages = []
+        for message in prompt:
```

Contributor: This would be Since it sends You might need to "create" a container prompt with those 2

Contributor: Or chat calls a curried version of this function that creates the format before calling this function
```diff
+            if isinstance(message["content"], str):
+                if not image_found and image_prompts:
+                    messages.append(
+                        Message(
+                            role=message["role"],
+                            content=[
+                                {"type": "image", "content": images[0]},
+                                {"type": "text", "content": message["content"]},
+                            ],
+                        )
+                    )
+                    image_found = True
+                else:
+                    messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True
+
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]

-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content

-        messages = [
-            Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+                messages.append(
+                    Message(
+                        role=message["role"],
+                        content=content,
+                    )
+                )
+
+        messages.append(
+            Message(
+                role="assistant",
+                content="",
+            )
+        )

         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
```
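
For reference, a small round-trip of the `image_url` handling added above, outside of torchchat (the data-URL prefix and the 8x8 test image are illustrative):

```python
import base64
from io import BytesIO

from PIL import Image

# Encode: build a data URL of the form the OpenAI-style content list carries.
buf = BytesIO()
Image.new("RGB", (8, 8), color="red").save(buf, format="PNG")
image_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")

# Decode: the same steps as the new image_url branch in _gen_model_input.
base64_decoded = base64.b64decode(image_url.split(";base64,")[1])
images = [Image.open(BytesIO(base64_decoded))]
print(images[0].size)  # (8, 8)
```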
```diff
@@ -812,7 +853,7 @@
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)

-            if is_multimodal:
+            if image_found:
                 batch = padded_collate_tiled_images_and_mask(
                     [data], pad_direction="left", pad_max_images=1
                 )
```
```diff
@@ -822,17 +863,27 @@
                 batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                     self.dtype
                 )

             else:
                 encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
                 batch = {}

             total_response_length = seq_len + max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
-            )
+            batch["causal_mask"] = torch.nn.functional.pad(
+                torch.tril(
+                    torch.ones(
+                        size=(total_response_length, total_response_length),
+                        dtype=torch.bool,
+                    )
+                ),
+                (
+                    0,
+                    max_seq_len - total_response_length,
+                    0,
+                    max_seq_len - total_response_length,
+                ),
+                value=0,
+            )

         logging.debug(encoded)
```
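
The padding keeps the mask at a fixed (max_seq_len, max_seq_len) shape: lower-triangular over the live prompt-plus-response window and all False in the padded tail, which matches the decoder cache now being sized to max_seq_length. A standalone sketch with toy sizes (the numbers are arbitrary):

```python
import torch

seq_len, max_new_tokens, max_seq_len = 5, 3, 16  # toy sizes
total_response_length = seq_len + max_new_tokens

# Lower-triangular causal mask for the current window, zero-padded on the
# right and bottom so the overall shape stays (max_seq_len, max_seq_len).
causal_mask = torch.nn.functional.pad(
    torch.tril(
        torch.ones(
            size=(total_response_length, total_response_length), dtype=torch.bool
        )
    ),
    (0, max_seq_len - total_response_length, 0, max_seq_len - total_response_length),
    value=0,
)

assert causal_mask.shape == (max_seq_len, max_seq_len)
assert not causal_mask[total_response_length:].any()  # padded rows attend to nothing
print(causal_mask.int())
```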
```diff
@@ -845,12 +896,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")

-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
```
```diff
@@ -896,6 +941,12 @@
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )

         if generator_args.chat_mode:
             print(
```
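
The history handed over at this call site is turned into torchtune Message objects by the loop shown earlier; for a two-turn chat with a single image the result looks roughly like this (a sketch with made-up strings, assuming torchtune's Message and a PIL image):

```python
from PIL import Image
from torchtune.data import Message

img = Image.new("RGB", (8, 8))  # stand-in for the decoded image

# What the message loop builds: the image is folded into the first user turn,
# later turns stay plain, and a trailing empty assistant message is appended
# ("Include empty assistant message for chat") so the template ends where
# generation should continue.
messages = [
    Message(
        role="user",
        content=[
            {"type": "image", "content": img},
            {"type": "text", "content": "What is in this image?"},
        ],
    ),
    Message(role="assistant", content="A dog riding a skateboard."),
    Message(role="user", content="What breed is the dog?"),
    Message(role="assistant", content=""),
]
```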
```diff
@@ -907,16 +958,16 @@
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")

-        elif not generator_args.is_torchtune_model:
-            max_seq_length = min(
-                encoded.size(0) + generator_args.max_new_tokens,
-                (
-                    text_transformer_args.block_size
-                    if text_transformer_args is not None
-                    else 2048
-                ),
-                max_seq_length,
-            )
+        # elif not generator_args.is_torchtune_model:
+        #     max_seq_length = min(
+        #         encoded.size(0) + generator_args.max_new_tokens,
+        #         (
+        #             text_transformer_args.block_size
+        #             if text_transformer_args is not None
+        #             else 2048
+        #         ),
+        #         max_seq_length,
+        #     )

         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1
```
Review comments:

- Is this line necessary to remove?
- This seems like it would be a problem with long prompt_lengths.
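
To make the last concern concrete: with the old clamp in chat() commented out, nothing at that call site guarantees that prompt_length + max_new_tokens fits in the cache sized by max_seq_length; only the min() that survives in generate() limits the new tokens. A hedged sketch of that guard in isolation (the function name is made up, not part of the PR):

```python
def clamp_new_tokens(
    prompt_length: int, max_new_tokens: int, max_seq_length: int, start_pos: int = 0
) -> int:
    # Mirrors the line kept in generate():
    #   max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
    # and fails loudly when the prompt alone already exceeds the budget.
    budget = max_seq_length - start_pos - prompt_length
    if budget <= 0:
        raise ValueError(
            f"prompt of length {prompt_length} does not fit in max_seq_length={max_seq_length}"
        )
    return min(max_new_tokens, budget)

# Example: a 2040-token prompt against a 2048-token budget leaves room for
# only 8 new tokens, regardless of the requested max_new_tokens.
print(clamp_new_tokens(prompt_length=2040, max_new_tokens=256, max_seq_length=2048))  # 8
```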