|
19 | 19 |
|
20 | 20 | from PIL import Image |
21 | 21 |
|
22 | | -from torchtune.data import Message, padded_collate_tiled_images_and_mask |
23 | | - |
24 | | -from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
25 | | - |
26 | 22 | from torchchat.cli.download import is_model_downloaded, load_model_configs |
27 | 23 | from torchchat.generate import Generator, GeneratorArgs |
28 | 24 |
|
29 | 25 | from torchchat.utils.build_utils import device_sync |
30 | 26 |
|
| 27 | +from torchtune.data import Message, padded_collate_tiled_images_and_mask |
| 28 | + |
| 29 | +from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
| 30 | + |
31 | 31 |
|
32 | 32 | """Dataclasses defined around the objects used the OpenAI API Chat specification. |
33 | 33 |
|
@@ -296,79 +296,44 @@ def __init__(self, *args, **kwargs): |
296 | 296 | f"{self.builder_args.device}_{self.builder_args.precision}" |
297 | 297 | ) |
298 | 298 |
|
299 | | - def _openai_messages_to_torchtune_messages( |
300 | | - self, messages: List[_AbstractMessage] |
| 299 | + def _gen_model_inputs_from_openai_completion_request( |
| 300 | + self, completion_request: CompletionRequest |
301 | 301 | ) -> List[Message]: |
302 | | - """Convert a list of OpenAI API messages to a list of TorchTune messages. |
| 302 | + """Generate model inputs from an OpenAI completion request. |
303 | 303 |
|
304 | 304 | Args: |
305 | | - messages: A list of OpenAI API messages. |
| 305 | + completion_request: Request object with prompt and other parameters. |
306 | 306 |
|
307 | 307 | Returns: |
308 | | - A list of Torchtune Messages. |
| 308 | + Model inputs. |
309 | 309 | """ |
310 | | - torchtune_messages = [] |
| 310 | + messages = completion_request.messages |
| 311 | + |
| 312 | + prompt = None |
| 313 | + images = None |
| 314 | + |
311 | 315 | for message in messages: |
312 | 316 | torchtune_contents = [] |
313 | 317 | if isinstance(message["content"], list): |
314 | 318 | for content_dict in message["content"]: |
315 | | - converted_content = [] |
316 | 319 | if content_dict["type"] == "text": |
317 | | - converted_content.append( |
318 | | - {"type": "text", "content": content_dict["text"]} |
319 | | - ) |
| 320 | + assert ( |
| 321 | + prompt is None |
| 322 | + ), "At most one text prompt is supported for each request" |
| 323 | + prompt = content_dict["text"] |
320 | 324 | elif content_dict["type"] == "image_url": |
| 325 | + assert ( |
| 326 | + images is None |
| 327 | + ), "At most one image is supported at the moment" |
| 328 | + |
321 | 329 | base64_decoded = base64.b64decode( |
322 | | - content_dict["image_url"].split(";base64,")[1] |
323 | | - ) |
324 | | - image = Image.open(BytesIO(base64_decoded)) |
325 | | - converted_content.append( |
326 | | - { |
327 | | - "type": "image", |
328 | | - "content": image, |
329 | | - } |
| 330 | + content_dict["image_url"].split(";base64,")[1] |
330 | 331 | ) |
331 | | - torchtune_messages.append( |
332 | | - Message(role=message["role"], content=converted_content, eot=False) |
333 | | - ) |
334 | | - return torchtune_messages |
| 332 | + images = [Image.open(BytesIO(base64_decoded))] |
335 | 333 |
|
336 | | - def _openai_messages_to_torchtune( |
337 | | - self, messages: List[_AbstractMessage] |
338 | | - ) -> List[Message]: |
339 | | - """Convert a list of OpenAI API messages to a list of TorchTune messages. |
| 334 | + assert prompt is not None, "Text prompt must be specified in the request" |
340 | 335 |
|
341 | | - Args: |
342 | | - messages: A list of OpenAI API messages. |
343 | | -
|
344 | | - Returns: |
345 | | - A list of Torchtune Messages. |
346 | | - """ |
347 | | - torchtune_messages = [] |
348 | | - for message in messages: |
349 | | - torchtune_contents = [] |
350 | | - if isinstance(message["content"], list): |
351 | | - for content in message["content"]: |
352 | | - if isinstance(content, dict): |
353 | | - if content["type"] == "image_url": |
354 | | - torchtune_contents.append({"type": "image"}) |
355 | | - elif content["type"] == "image_file": |
356 | | - torchtune_contents.append({"type": "image"}) |
357 | | - elif content["type"] == "text": |
358 | | - torchtune_contents.append( |
359 | | - {"type": "text", "content": content["text"]} |
360 | | - ) |
361 | | - elif isinstance(content, str): |
362 | | - torchtune_contents.append({"type": "text", "text": content}) |
363 | | - else: |
364 | | - torchtune_contents.append( |
365 | | - {"type": "text", "content": message["content"]} |
366 | | - ) |
367 | | - torchtune_messages.append( |
368 | | - Message(role=message["role"], content=torchtune_contents, eot=False) |
369 | | - ) |
370 | | - torchtune_messages.append(Message(role="assistant", content="", eot=False)) |
371 | | - return torchtune_messages |
| 336 | + return self._gen_model_input(prompt, images, completion_request.max_tokens) |
372 | 337 |
|
373 | 338 | def chunked_completion(self, completion_request: CompletionRequest): |
374 | 339 | """Handle a chat completion request and yield a chunked response. |
@@ -396,63 +361,13 @@ def chunked_completion(self, completion_request: CompletionRequest): |
396 | 361 | # Initialize counters for chunk responses and encode the prompt. |
397 | 362 | id = str(uuid.uuid4()) |
398 | 363 |
|
399 | | - idx = 0 |
400 | | - images = [] |
401 | | - |
402 | 364 | device_sync(device=self.builder_args.device) |
403 | | - for message in completion_request.messages: |
404 | | - contents = message["content"] |
405 | | - if isinstance(contents, list): |
406 | | - for content in message["content"]: |
407 | | - if content["type"] == "image_url": |
408 | | - base64_decoded = base64.b64decode( |
409 | | - content["image_url"].split(";base64,")[1] |
410 | | - ) |
411 | | - images.append(Image.open(BytesIO(base64_decoded))) |
412 | | - print("images:", len(images), flush=True) |
413 | | - if len(images) > 0: |
414 | | - transform = llama3_2_vision_transform( |
415 | | - str(self.tokenizer_args.tokenizer_path) |
416 | | - ) |
417 | | - torchtune_messages = self._openai_messages_to_torchtune_messages( |
418 | | - completion_request.messages |
419 | | - ) |
420 | | - data = transform( |
421 | | - {"images": images, "messages": torchtune_messages}, inference=True |
422 | | - ) |
423 | | - seq_len = len(data["tokens"]) |
424 | | - total_response_length = seq_len + completion_request.max_tokens |
425 | | - causal_mask = torch.tril( |
426 | | - torch.ones( |
427 | | - size=(total_response_length, total_response_length), |
428 | | - dtype=torch.bool, |
429 | | - ) |
430 | | - ) |
431 | | - input_pos = torch.arange(total_response_length) |
432 | | - |
433 | | - with torch.no_grad(): |
434 | | - with torch.device(self.builder_args.device): |
435 | | - batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1) |
436 | | - batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision) |
437 | | - batch["causal_mask"] = causal_mask |
438 | | - batch["input_pos"] = input_pos[None, :seq_len] |
439 | | - batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] |
440 | | - |
441 | | - #batch = padded_collate([data], self.builder_args.device) |
442 | | - encoded = batch["tokens"].view(-1) |
443 | | - else: |
444 | | - tokens = self.chat_formatter.encode_dialog_prompt( |
445 | | - dialog=[ |
446 | | - {"role": message["role"], "content": message["content"]} |
447 | | - for message in completion_request.messages |
448 | | - ] |
449 | | - ) |
450 | | - print("tokens:", self.tokenizer.decode(tokens), flush=True) |
451 | | - encoded = torch.tensor( |
452 | | - tokens, dtype=torch.int, device=self.builder_args.device |
453 | | - ) |
454 | | - batch = None |
455 | 365 |
|
| 366 | + encoded, batch = self._gen_model_inputs_from_openai_completion_request( |
| 367 | + completion_request |
| 368 | + ) |
| 369 | + |
| 370 | + idx = 0 |
456 | 371 | start_pos = 0 |
457 | 372 |
|
458 | 373 | generator_args = GeneratorArgs( |
|
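For reference, the inline vision preprocessing removed from `chunked_completion` above can be condensed into the sketch below. It reuses only the calls visible in the removed lines (`llama3_2_vision_transform`, `padded_collate_tiled_images_and_mask`); the function name `build_vision_batch` and the parameters `tokenizer_path`, `max_new_tokens`, `device`, and `dtype` are stand-ins, and in the refactored code this work is presumably handled by `self._gen_model_input`:

```python
import torch
from torchtune.data import padded_collate_tiled_images_and_mask
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform


def build_vision_batch(tokenizer_path, messages, images, max_new_tokens, device, dtype):
    """Sketch of the removed image path: `messages` are torchtune Messages,
    `images` a list of PIL images."""
    transform = llama3_2_vision_transform(str(tokenizer_path))
    data = transform({"images": images, "messages": messages}, inference=True)

    seq_len = len(data["tokens"])
    total_len = seq_len + max_new_tokens

    # Causal mask and position ids sized for prompt plus generated tokens.
    causal_mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    input_pos = torch.arange(total_len)

    with torch.no_grad(), torch.device(device):
        batch = padded_collate_tiled_images_and_mask(
            [data], pad_direction="left", pad_max_images=1
        )
        batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(dtype)
        batch["causal_mask"] = causal_mask
        batch["input_pos"] = input_pos[None, :seq_len]
        batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]

    encoded = batch["tokens"].view(-1)
    return encoded, batch
```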