# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

+import base64
import os
import time
import uuid

from abc import ABC
from dataclasses import dataclass
+from io import BytesIO
from pwd import getpwuid
from typing import Any, Dict, List, Optional, Union

import torch

+from _torchchat_test_script import flamingo_transform, padded_collate
+from PIL import Image
+from torchtune.data import Message
+
from torchchat.cli.download import is_model_downloaded, load_model_configs
from torchchat.generate import Generator, GeneratorArgs

# Message classes and associated objects - see the types of Messages under "Create Chat Completion >>> Request body >>> messages"


+@dataclass
+class _ContentPart(ABC):
+    """A single part of a message content field.
+
+    See the "Assistants >>> Messages >>> Create Message >>> Request body >>> content >>> Show possible types" section of the OpenAI API docs for more details.
+    """
+
+    type: str
+
+
+@dataclass
+class ImageFile:
+    file_id: str
+    detail: Optional[str]
+
+
+@dataclass
+class ImageFileContentPart(_ContentPart):
+    type: str = "image_file"
+    image_file: Optional[ImageFile] = None
+
+
+@dataclass
+class ImageUrl:
+    url: str
+    detail: Optional[str]
+
+
+@dataclass
+class ImageUrlContentPart(_ContentPart):
+    type: str = "image_url"
+    image_url: Optional[ImageUrl] = None
+
+
+@dataclass
+class TextContentPart(_ContentPart):
+    text: str = ""
+    type: str = "text"
+
+
@dataclass
class _AbstractMessage(ABC):
    """Base class with common parameters for message types.
@@ -42,7 +88,7 @@ class _AbstractMessage(ABC):
    """

    role: str
-    content: Optional[str] = None
+    content: Optional[Union[List[_ContentPart], str]] = None


@dataclass
@@ -185,7 +231,7 @@ class ChunkDelta:

    tool_calls: Optional[List[ToolCall]]
    role: Optional[str]
-    content: Optional[str]
+    content: Optional[Union[List[_ContentPart], str]] = None


@dataclass
@@ -232,18 +278,55 @@ def __init__(self, *args, **kwargs):
        """

        super().__init__(*args, **kwargs)
-        self.max_seq_length = (
-            self.model.config.transformer_args["text"].max_seq_length
-            + self.speculative_builder_args.speculate_k
-            + 1
-            if self.draft_model is not None
-            else self.model.config.transformer_args["text"].max_seq_length
-        )
+        self.max_seq_length = 128
+        if self.model.config.transformer_args.get("text", None):
+            self.max_seq_length = (
+                self.model.config.transformer_args["text"].max_seq_length
+                + self.speculative_builder_args.speculate_k
+                + 1
+                if self.draft_model is not None
+                else self.model.config.transformer_args["text"].max_seq_length
+            )
        # The System fingerprint is a unique identifier for the model and its configuration.
        self.system_fingerprint = (
            f"{self.builder_args.device}_{self.builder_args.precision}"
        )

+    def _openai_messages_to_torchtune(self, messages: List[_AbstractMessage]):
+        """Convert a list of OpenAI API messages to a list of torchtune messages.
+
+        Args:
+            messages: A list of OpenAI API messages.
+
+        Returns:
+            A list of torchtune Messages.
+        """
+        torchtune_messages = []
+        for message in messages:
+            torchtune_contents = []
+            if isinstance(message["content"], list):
+                for content in message["content"]:
+                    if isinstance(content, dict):
+                        if content["type"] == "image_url":
+                            torchtune_contents.append({"type": "image"})
+                        elif content["type"] == "image_file":
+                            torchtune_contents.append({"type": "image"})
+                        elif content["type"] == "text":
+                            torchtune_contents.append(
+                                {"type": "text", "content": content["text"]}
+                            )
+                    elif isinstance(content, str):
+                        torchtune_contents.append({"type": "text", "content": content})
+            else:
+                torchtune_contents.append(
+                    {"type": "text", "content": message["content"]}
+                )
+            torchtune_messages.append(
+                Message(role=message["role"], content=torchtune_contents, eot=False)
+            )
+        torchtune_messages.append(Message(role="assistant", content="", eot=False))
+        return torchtune_messages
+
    def chunked_completion(self, completion_request: CompletionRequest):
        """Handle a chat completion request and yield a chunked response.

@@ -271,15 +354,42 @@ def chunked_completion(self, completion_request: CompletionRequest):
        id = str(uuid.uuid4())

        idx = 0
-        tokens = self.chat_formatter.encode_dialog_prompt(
-            dialog=[
-                {"role": message["role"], "content": message["content"]}
-                for message in completion_request.messages
-            ]
-        )
+        images = []

-        encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
-        print(self.tokenizer.decode(tokens))
+        device_sync(device=self.builder_args.device)
+        for message in completion_request.messages:
+            contents = message["content"]
+            if isinstance(contents, list):
+                for content in message["content"]:
+                    if content["type"] == "image_url":
+                        base64_decoded = base64.b64decode(
+                            content["image_url"].split(";base64,")[1]
+                        )
+                        images.append(Image.open(BytesIO(base64_decoded)))
+        print("images:", len(images), flush=True)
+        if len(images) > 0:
+            transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
+            torchtune_messages = self._openai_messages_to_torchtune(
+                completion_request.messages
+            )
+            data = transform(
+                {"images": images, "messages": torchtune_messages}, inference=True
+            )
+            batch = padded_collate([data], self.builder_args.device)
+            batch.pop("mask")
+            encoded = batch["tokens"]
+        else:
+            tokens = self.chat_formatter.encode_dialog_prompt(
+                dialog=[
+                    {"role": message["role"], "content": message["content"]}
+                    for message in completion_request.messages
+                ]
+            )
+            print("tokens:", self.tokenizer.decode(tokens), flush=True)
+            encoded = torch.tensor(
+                tokens, dtype=torch.int, device=self.builder_args.device
+            )
+            batch = None

        start_pos = 0

@@ -293,7 +403,7 @@ def chunked_completion(self, completion_request: CompletionRequest):
            encoded_prompt=encoded,
            temperature=float(completion_request.temperature),
            chat_mode=False,
-            sequential_prefill=True,
+            sequential_prefill=False,
        )

        def callback(x, *, done_generating=False):
@@ -313,6 +423,7 @@ def callback(x, *, done_generating=False):
            draft_model=self.draft_model,
            speculate_k=generator_args.speculate_k,
            chat_mode=generator_args.chat_mode,
+            batch=batch,
            callback=callback,
            temperature=generator_args.temperature,
            top_k=generator_args.top_k,
@@ -323,10 +434,12 @@
        ):
            if y is None:
                continue
+
            elif y.item() == self.tokenizer.eos_id:
                # Stop generation if the EOS token is generated.
                break

+            y = y.view(-1)
            # Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token.
            content = "".join(
                self.tokenizer.decode([self.tokenizer.encode(".")[0]] + y.tolist())[1:]
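
For anyone exercising the new image path, here is a minimal client-side sketch of a request payload the parsing above would accept. It is an illustration only: the file name and prompt are placeholders, and it follows the handler's current convention of putting the raw data-URL string directly in the `image_url` field (rather than the nested `{"url": ..., "detail": ...}` object that the `ImageUrlContentPart` dataclass describes), since `chunked_completion` splits that string on `";base64,"`.

```python
import base64

# Hypothetical local image; any PNG/JPEG works.
with open("example.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# Message list in the shape chunked_completion expects: a list of content
# parts, with the data URL stored directly under "image_url".
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": f"data:image/png;base64,{image_b64}"},
        ],
    }
]
```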
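As a companion sketch (hand-written, not captured from a run), this is roughly what `_openai_messages_to_torchtune` produces for that payload before it is handed to `flamingo_transform`: image parts collapse to bare `{"type": "image"}` markers (the decoded PIL images travel separately in the `images` list), text parts keep their text, and an empty assistant turn is appended to cue generation.

```python
from torchtune.data import Message

# Hand-built equivalent of the conversion in this diff (eot=False throughout).
# Note: the bare {"type": "image"} marker mirrors the code above; newer
# torchtune releases may expect the image itself under a "content" key.
torchtune_messages = [
    Message(
        role="user",
        content=[
            {"type": "text", "content": "What is in this image?"},
            {"type": "image"},  # marker only; the PIL image is passed via "images"
        ],
        eot=False,
    ),
    Message(role="assistant", content="", eot=False),  # empty turn cues the reply
]
```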