@@ -19,7 +19,8 @@
 
 from PIL import Image
 
-from torchtune.data import Message, padded_collate
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
 from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
 
 from torchchat.cli.download import is_model_downloaded, load_model_configs
@@ -295,6 +296,42 @@ def __init__(self, *args, **kwargs):
             f"{self.builder_args.device}_{self.builder_args.precision}"
         )
 
+    def _openai_messages_to_torchtune_messages(
+        self, messages: List[_AbstractMessage]
+    ) -> List[Message]:
+        """Convert a list of OpenAI API messages to a list of TorchTune messages.
+
+        Args:
+            messages: A list of OpenAI API messages.
+
+        Returns:
+            A list of Torchtune Messages.
+        """
+        torchtune_messages = []
+        for message in messages:
+            torchtune_contents = []
+            if isinstance(message["content"], list):
+                for content_dict in message["content"]:
+                    converted_content = []
+                    if content_dict["type"] == "text":
+                        converted_content.append(
+                            {"type": "text", "content": content_dict["text"]}
+                        )
+                    elif content_dict["type"] == "image_url":
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        converted_content.append(
+                            {
+                                "type": "image",
+                                "content": Image.open(BytesIO(base64_decoded)),
+                            }
+                        )
+                    torchtune_messages.append(
+                        Message(role=message["role"], content=converted_content, eot=False)
+                    )
+        return torchtune_messages
+
     def _openai_messages_to_torchtune(
         self, messages: List[_AbstractMessage]
     ) -> List[Message]:
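
As a rough illustration of what the new `_openai_messages_to_torchtune_messages` helper expects, the sketch below builds an OpenAI-style chat message whose image part is a base64 data URI string, which is the format the `split(";base64,")[1]` call above assumes. The file path is a placeholder, not something from this PR:

```python
import base64

# Placeholder image path; any local image would do.
with open("cat.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

openai_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        # Note: image_url is a plain data-URI string here, matching the split above.
        {"type": "image_url", "image_url": f"data:image/png;base64,{encoded}"},
    ],
}
```

Given such input, the helper emits one torchtune `Message` per content part (role copied from the OpenAI message, `eot=False`), decoding the image part back into a `PIL.Image` via `BytesIO`.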
@@ -376,15 +413,32 @@ def chunked_completion(self, completion_request: CompletionRequest):
             transform = llama3_2_vision_transform(
                 str(self.tokenizer_args.tokenizer_path)
             )
-            torchtune_messages = self._openai_messages_to_torchtune(
+            torchtune_messages = self._openai_messages_to_torchtune_messages(
                 completion_request.messages
             )
             data = transform(
                 {"images": images, "messages": torchtune_messages}, inference=True
             )
-            batch = padded_collate([data], self.builder_args.device)
-            batch.pop("mask")
-            encoded = batch["tokens"]
+            seq_len = len(data["tokens"])
+            total_response_length = seq_len + completion_request.max_tokens
+            causal_mask = torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            )
+            input_pos = torch.arange(total_response_length)
+
+            with torch.no_grad():
+                with torch.device(self.builder_args.device):
+                    batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision)
+                    batch["causal_mask"] = causal_mask
+                    batch["input_pos"] = input_pos[None, :seq_len]
+                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+
+                    #batch = padded_collate([data], self.builder_args.device)
+            encoded = batch["tokens"].view(-1)
         else:
             tokens = self.chat_formatter.encode_dialog_prompt(
                 dialog=[
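
For reviewers, here is a standalone sketch (illustrative values only, not code from this PR) of the mask and position bookkeeping the new branch sets up: a lower-triangular causal mask sized to the prompt plus the generation budget, and an `input_pos` vector that is sliced to the prompt length for the prefill step:

```python
import torch

seq_len = 10      # stands in for len(data["tokens"]), the prompt length
max_tokens = 5    # stands in for completion_request.max_tokens
total_len = seq_len + max_tokens

# Boolean lower-triangular mask covering prompt + generated tokens.
causal_mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))

# Absolute positions; only the prompt slice is passed at prefill time.
input_pos = torch.arange(total_len)
prefill_pos = input_pos[None, :seq_len]

print(causal_mask.shape)  # torch.Size([15, 15])
print(prefill_pos.shape)  # torch.Size([1, 10])
```

The same slicing idea is applied to `encoder_mask` above, which is trimmed to the prompt length before decoding begins.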