19 | 19 |
20 | 20 | from PIL import Image |
21 | 21 |
22 | | -from torchtune.data import Message, padded_collate |
| 22 | +from torchtune.data import Message, padded_collate_tiled_images_and_mask |
| 23 | + |
23 | 24 | from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
24 | 25 |
25 | 26 | from torchchat.cli.download import is_model_downloaded, load_model_configs |
@@ -288,13 +289,50 @@ def __init__(self, *args, **kwargs): |
288 | 289 |                 else self.model.text_transformer_args.max_seq_length |
289 | 290 |             ) |
290 | 291 |         except: |
291 |     | -            # can not find max_seq_length in model config, use default value |
292 |     | -            self.max_seq_length = 128 |
    | 292 | +            self.max_seq_length = 2048 |
    | 293 | +            print(f"Cannot find max_seq_length in model config; using default value: {self.max_seq_length}") |
293 | 294 |         # The System fingerprint is a unique identifier for the model and its configuration. |
294 | 295 |         self.system_fingerprint = ( |
295 | 296 |             f"{self.builder_args.device}_{self.builder_args.precision}" |
296 | 297 |         ) |
297 | 298 |
    | 299 | +    def _openai_messages_to_torchtune_messages( |
    | 300 | +        self, messages: List[_AbstractMessage] |
    | 301 | +    ) -> List[Message]: |
    | 302 | +        """Convert a list of OpenAI API messages to a list of torchtune messages. |
    | 303 | + |
    | 304 | +        Args: |
    | 305 | +            messages: A list of OpenAI API messages. |
    | 306 | + |
    | 307 | +        Returns: |
    | 308 | +            A list of torchtune Messages. |
    | 309 | +        """ |
    | 310 | +        torchtune_messages = [] |
    | 311 | +        for message in messages: |
    | 312 | +            converted_content = [] |
    | 313 | +            if isinstance(message["content"], list): |
    | 314 | +                for content_dict in message["content"]: |
    | 315 | +                    if content_dict["type"] == "text": |
    | 316 | +                        converted_content.append( |
    | 317 | +                            {"type": "text", "content": content_dict["text"]} |
    | 318 | +                        ) |
    | 319 | +                    elif content_dict["type"] == "image_url": |
    | 320 | +                        base64_decoded = base64.b64decode( |
    | 321 | +                            content_dict["image_url"].split(";base64,")[1] |
    | 322 | +                        ) |
    | 323 | +                        image = Image.open(BytesIO(base64_decoded)) |
    | 324 | +                        converted_content.append( |
    | 325 | +                            {"type": "image", "content": image} |
    | 326 | +                        ) |
    | 327 | +            else: |
    | 328 | +                converted_content.append( |
    | 329 | +                    {"type": "text", "content": message["content"]} |
    | 330 | +                ) |
    | 331 | +            torchtune_messages.append( |
    | 332 | +                Message(role=message["role"], content=converted_content, eot=False) |
    | 333 | +            ) |
    | 334 | +        return torchtune_messages |
| 335 | + |
298 | 336 |     def _openai_messages_to_torchtune( |
299 | 337 |         self, messages: List[_AbstractMessage] |
300 | 338 |     ) -> List[Message]: |
@@ -376,15 +414,32 @@ def chunked_completion(self, completion_request: CompletionRequest): |
376 | 414 |             transform = llama3_2_vision_transform( |
377 | 415 |                 str(self.tokenizer_args.tokenizer_path) |
378 | 416 |             ) |
379 |     | -            torchtune_messages = self._openai_messages_to_torchtune( |
    | 417 | +            torchtune_messages = self._openai_messages_to_torchtune_messages( |
380 | 418 |                 completion_request.messages |
381 | 419 |             ) |
382 | 420 |             data = transform( |
383 | 421 |                 {"images": images, "messages": torchtune_messages}, inference=True |
384 | 422 |             ) |
385 |     | -            batch = padded_collate([data], self.builder_args.device) |
386 |     | -            batch.pop("mask") |
387 |     | -            encoded = batch["tokens"] |
    | 423 | +            seq_len = len(data["tokens"]) |
    | 424 | +            total_response_length = seq_len + completion_request.max_tokens |
    | 425 | +            # Causal mask over the prompt plus every token to be generated. |
    | 426 | +            causal_mask = torch.tril( |
    | 427 | +                torch.ones( |
    | 428 | +                    size=(total_response_length, total_response_length), |
    | 429 | +                    dtype=torch.bool, |
    | 430 | +                ) |
    | 431 | +            ) |
    | 432 | +            input_pos = torch.arange(total_response_length) |
    | 433 | + |
    | 434 | +            with torch.no_grad(): |
    | 435 | +                with torch.device(self.builder_args.device): |
    | 436 | +                    batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1) |
    | 437 | +                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision) |
    | 438 | +                    batch["causal_mask"] = causal_mask |
    | 439 | +                    # At prefill time, only the prompt positions are fed to the model. |
    | 440 | +                    batch["input_pos"] = input_pos[None, :seq_len] |
    | 441 | +                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] |
    | 442 | +            encoded = batch["tokens"].view(-1) |
388 | 443 |         else: |
389 | 444 |             tokens = self.chat_formatter.encode_dialog_prompt( |
390 | 445 |                 dialog=[ |
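For reviewers, a minimal standalone sketch of the image branch in `_openai_messages_to_torchtune_messages`. The one-pixel PNG and the message payload are fabricated for illustration, and it mirrors the diff's assumption that `image_url` arrives as a raw data-URL string (not OpenAI's nested `{"url": ...}` object); only Pillow is required:

```python
import base64
from io import BytesIO

from PIL import Image

# Hypothetical request payload: a one-pixel PNG encoded as a data URL,
# standing in for what the handler receives as message["content"].
img = Image.new("RGB", (1, 1), color="red")
buf = BytesIO()
img.save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

content = [
    {"type": "text", "text": "What is in this image?"},
    {"type": "image_url", "image_url": data_url},
]

# Same decode path as the diff: split off the base64 payload, decode it,
# and open the bytes as a PIL image.
converted = []
for part in content:
    if part["type"] == "text":
        converted.append({"type": "text", "content": part["text"]})
    elif part["type"] == "image_url":
        raw = base64.b64decode(part["image_url"].split(";base64,")[1])
        converted.append({"type": "image", "content": Image.open(BytesIO(raw))})

assert converted[0]["content"] == "What is in this image?"
assert converted[1]["content"].size == (1, 1)
```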
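And a minimal sketch of the prefill masking math added in `chunked_completion`. The prompt length and `max_tokens` values are made up, and only plain `torch` is needed:

```python
import torch

# Hypothetical sizes: a 10-token prompt plus room for 6 generated tokens
# (completion_request.max_tokens in the diff above).
seq_len = 10
max_tokens = 6
total_response_length = seq_len + max_tokens

# Lower-triangular boolean mask: position i may attend to positions <= i.
# It spans the prompt plus every token that will be generated.
causal_mask = torch.tril(
    torch.ones(total_response_length, total_response_length, dtype=torch.bool)
)
input_pos = torch.arange(total_response_length)

# Prefill feeds only the prompt, so the position ids (and, in the diff,
# the encoder mask) are sliced to the first seq_len entries.
prefill_input_pos = input_pos[None, :seq_len]  # shape: (1, seq_len)

assert prefill_input_pos.shape == (1, seq_len)
assert causal_mask[5, :6].all() and not causal_mask[5, 6]
```

Sizing the mask to `total_response_length` up front, rather than to `seq_len`, is presumably what lets later decode steps index successive rows of the same mask with a growing `input_pos` instead of rebuilding it per token.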