This repository was archived by the owner on Sep 10, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 248
OpenAI API: Changes to enable multi-modal for 3.2 11B #1211
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,8 @@ | |
|
|
||
| from PIL import Image | ||
|
|
||
| from torchtune.data import Message, padded_collate | ||
| from torchtune.data import Message, padded_collate_tiled_images_and_mask | ||
|
|
||
| from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform | ||
|
|
||
| from torchchat.cli.download import is_model_downloaded, load_model_configs | ||
|
|
@@ -288,13 +289,50 @@ def __init__(self, *args, **kwargs): | |
| else self.model.text_transformer_args.max_seq_length | ||
| ) | ||
| except: | ||
| # can not find max_seq_length in model config, use default value | ||
| self.max_seq_length = 128 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. haha now i can see where 128 comes from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It happens to be the same as the head_dim, very tricky to trace it LOL There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yep. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks to @iseeyuan for spotting it. |
||
| self.max_seq_length = 2048 | ||
| print(f"can not find max_seq_length in model config, use default value: {self.max_seq_length}") | ||
| # The System fingerprint is a unique identifier for the model and its configuration. | ||
| self.system_fingerprint = ( | ||
| f"{self.builder_args.device}_{self.builder_args.precision}" | ||
| ) | ||
|
|
||
| def _openai_messages_to_torchtune_messages( | ||
| self, messages: List[_AbstractMessage] | ||
| ) -> List[Message]: | ||
| """Convert a list of OpenAI API messages to a list of TorchTune messages. | ||
|
|
||
| Args: | ||
| messages: A list of OpenAI API messages. | ||
|
|
||
| Returns: | ||
| A list of Torchtune Messages. | ||
| """ | ||
| torchtune_messages = [] | ||
| for message in messages: | ||
| torchtune_contents = [] | ||
| if isinstance(message["content"], list): | ||
| for content_dict in message["content"]: | ||
| converted_content = [] | ||
| if content_dict["type"] == "text": | ||
| converted_content.append( | ||
| {"type": "text", "content": content_dict["text"]} | ||
| ) | ||
| elif content_dict["type"] == "image_url": | ||
| base64_decoded = base64.b64decode( | ||
| content_dict["image_url"].split(";base64,")[1] | ||
| ) | ||
| image = Image.open(BytesIO(base64_decoded)) | ||
| converted_content.append( | ||
| { | ||
| "type": "image", | ||
| "content": image, | ||
| } | ||
| ) | ||
| torchtune_messages.append( | ||
| Message(role=message["role"], content=converted_content, eot=False) | ||
| ) | ||
| return torchtune_messages | ||
|
|
||
| def _openai_messages_to_torchtune( | ||
| self, messages: List[_AbstractMessage] | ||
| ) -> List[Message]: | ||
|
|
@@ -376,15 +414,32 @@ def chunked_completion(self, completion_request: CompletionRequest): | |
| transform = llama3_2_vision_transform( | ||
| str(self.tokenizer_args.tokenizer_path) | ||
| ) | ||
| torchtune_messages = self._openai_messages_to_torchtune( | ||
| torchtune_messages = self._openai_messages_to_torchtune_messages( | ||
| completion_request.messages | ||
| ) | ||
| data = transform( | ||
| {"images": images, "messages": torchtune_messages}, inference=True | ||
| ) | ||
| batch = padded_collate([data], self.builder_args.device) | ||
| batch.pop("mask") | ||
| encoded = batch["tokens"] | ||
| seq_len = len(data["tokens"]) | ||
| total_response_length = seq_len + completion_request.max_tokens | ||
| causal_mask = torch.tril( | ||
| torch.ones( | ||
| size=(total_response_length, total_response_length), | ||
| dtype=torch.bool, | ||
| ) | ||
| ) | ||
| input_pos = torch.arange(total_response_length) | ||
|
|
||
| with torch.no_grad(): | ||
| with torch.device(self.builder_args.device): | ||
| batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1) | ||
| batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision) | ||
| batch["causal_mask"] = causal_mask | ||
| batch["input_pos"] = input_pos[None, :seq_len] | ||
| batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] | ||
|
|
||
| #batch = padded_collate([data], self.builder_args.device) | ||
| encoded = batch["tokens"].view(-1) | ||
| else: | ||
| tokens = self.chat_formatter.encode_dialog_prompt( | ||
| dialog=[ | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One of the things on our list is having a unification configuration system for both tune-backend model and chat-backend models to get rid of the try .. except here.