diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py
index e0e309d5b..aa63782fb 100644
--- a/torchchat/usages/openai_api.py
+++ b/torchchat/usages/openai_api.py
@@ -19,7 +19,8 @@
 
 from PIL import Image
 
-from torchtune.data import Message, padded_collate
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
 from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
 
 from torchchat.cli.download import is_model_downloaded, load_model_configs
@@ -288,13 +289,50 @@ def __init__(self, *args, **kwargs):
                 else self.model.text_transformer_args.max_seq_length
             )
         except:
-            # can not find max_seq_length in model config, use default value
-            self.max_seq_length = 128
+            self.max_seq_length = 2048
+            print(
+                f"Cannot find max_seq_length in model config; using default: {self.max_seq_length}"
+            )
         # The System fingerprint is a unique identifier for the model and its configuration.
         self.system_fingerprint = (
             f"{self.builder_args.device}_{self.builder_args.precision}"
         )
 
+    def _openai_messages_to_torchtune_messages(
+        self, messages: List[_AbstractMessage]
+    ) -> List[Message]:
+        """Convert a list of OpenAI API messages to a list of torchtune Messages.
+
+        Args:
+            messages: A list of OpenAI API messages.
+
+        Returns:
+            A list of torchtune Messages.
+        """
+        torchtune_messages = []
+        for message in messages:
+            torchtune_contents = []
+            if isinstance(message["content"], list):
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        torchtune_contents.append(
+                            {"type": "text", "content": content_dict["text"]}
+                        )
+                    elif content_dict["type"] == "image_url":
+                        # The image arrives as a data URI; decode the base64
+                        # payload after the ";base64," marker.
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        image = Image.open(BytesIO(base64_decoded))
+                        torchtune_contents.append(
+                            {"type": "image", "content": image}
+                        )
+            torchtune_messages.append(
+                Message(role=message["role"], content=torchtune_contents, eot=False)
+            )
+        return torchtune_messages
+
     def _openai_messages_to_torchtune(
         self, messages: List[_AbstractMessage]
     ) -> List[Message]:
@@ -376,15 +414,32 @@ def chunked_completion(self, completion_request: CompletionRequest):
                 transform = llama3_2_vision_transform(
                     str(self.tokenizer_args.tokenizer_path)
                 )
-                torchtune_messages = self._openai_messages_to_torchtune(
+                torchtune_messages = self._openai_messages_to_torchtune_messages(
                     completion_request.messages
                 )
                 data = transform(
                     {"images": images, "messages": torchtune_messages}, inference=True
                 )
-                batch = padded_collate([data], self.builder_args.device)
-                batch.pop("mask")
-                encoded = batch["tokens"]
+                seq_len = len(data["tokens"])
+                total_response_length = seq_len + completion_request.max_tokens
+                # Full causal mask over prompt plus generation budget:
+                # position i may attend to positions 0..i.
+                causal_mask = torch.tril(
+                    torch.ones(
+                        size=(total_response_length, total_response_length),
+                        dtype=torch.bool,
+                    )
+                )
+                input_pos = torch.arange(total_response_length)
+
+                with torch.no_grad():
+                    with torch.device(self.builder_args.device):
+                        batch = padded_collate_tiled_images_and_mask(
+                            [data], pad_direction="left", pad_max_images=1
+                        )
+                        batch["encoder_input"]["images"] = batch["encoder_input"][
+                            "images"
+                        ].to(self.builder_args.precision)
+                        batch["causal_mask"] = causal_mask
+                        # Prefill consumes only the first seq_len positions.
+                        batch["input_pos"] = input_pos[None, :seq_len]
+                        batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+
+                encoded = batch["tokens"].view(-1)
             else:
                 tokens = self.chat_formatter.encode_dialog_prompt(
                     dialog=[
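
For context, a minimal sketch of the request shape the new `_openai_messages_to_torchtune_messages` helper parses. Note that the patch reads `content_dict["image_url"]` as a raw data-URI string rather than the nested `{"url": ...}` object in the current OpenAI spec, so the payload below is an assumption matching that code path; the base64 payload is truncated for illustration.

```python
# Hypothetical request payload, matching what the converter above parses.
# The image part is a data-URI string; only the text after ";base64," is
# decoded. The payload here is truncated and for illustration only.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": "data:image/png;base64,iVBORw0KGgo..."},
        ],
    }
]
```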
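A standalone sketch of the mask bookkeeping the new prefill path sets up, assuming only `torch`; `seq_len` and `max_new_tokens` are stand-ins for the transformed prompt length and `completion_request.max_tokens`:

```python
import torch

seq_len = 5          # tokens produced by the vision transform for the prompt
max_new_tokens = 3   # stand-in for completion_request.max_tokens
total_response_length = seq_len + max_new_tokens

# Lower-triangular boolean mask over the full response length: row i is
# True at columns 0..i, so each position attends only to earlier ones.
causal_mask = torch.tril(
    torch.ones(total_response_length, total_response_length, dtype=torch.bool)
)
input_pos = torch.arange(total_response_length)

# Prefill uses only the first seq_len positions; decode steps would then
# advance input_pos one slot at a time past seq_len.
prefill_pos = input_pos[None, :seq_len]      # shape (1, seq_len)
print(causal_mask.shape, prefill_pos.shape)  # torch.Size([8, 8]) torch.Size([1, 5])
```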