import torch._dynamo.config
import torch._inductor.config

- from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-
from PIL import Image

+ # torchtune model definition dependencies
+ from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+ from torchtune.generation import sample as tune_sample
+ from torchtune.models.llama3 import llama3_tokenizer
+
+ from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+ from torchtune.training import set_default_dtype
+
from torchchat.cli.builder import (
    _initialize_model,
    _initialize_tokenizer,
from torchchat.utils.build_utils import device_sync, set_precision
from torchchat.utils.device_info import get_device_info

- # torchtune model definition dependencies
- from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
- from torchtune.generation import sample as tune_sample
- from torchtune.models.llama3 import llama3_tokenizer
- from torchtune.training import set_default_dtype
-

class _ChatFormatter(ABC):
    def __init__(self, tokenizer):
@@ -357,8 +357,8 @@ def prefill(

            # TODO: Verify sequential prefill works with multimodal models
            is_multimodal = True
-             if 'encoder_input' in batch:
-                 encoder_input = batch['encoder_input']
+             if "encoder_input" in batch:
+                 encoder_input = batch["encoder_input"]
                encoder_mask = batch["encoder_mask"]
                is_multimodal = True
            else:
@@ -369,7 +369,13 @@ def prefill(
            seq_len = x.size(1)
            mask = batch["causal_mask"][None, :seq_len]
            input_pos = input_pos.view(1, -1)
-             logits = model(tokens=x, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+             logits = model(
+                 tokens=x,
+                 mask=mask,
+                 encoder_input=encoder_input,
+                 input_pos=input_pos,
+                 encoder_mask=encoder_mask,
+             )[:, -1]

            if is_multimodal:
                batch["encoder_mask"] = batch["encoder_mask"][:, -1:]
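
The mask bookkeeping in this prefill path can be sketched shape-wise as follows. This is a minimal illustration with assumed sizes (a 32-token prompt, a 2048-entry decoder cache, 6404 encoder positions); the batch itself is constructed elsewhere in torchchat, and only the two slicing operations are taken from the hunk above.

    import torch

    total_len, seq_len, enc_len = 2048, 32, 6404  # assumed sizes

    # batch["causal_mask"]: one row per decoder position (lower-triangular).
    causal_mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    # batch["encoder_mask"]: which encoder tokens each prompt token can see.
    encoder_mask = torch.ones(1, seq_len, enc_len, dtype=torch.bool)

    # Prefill consumes the first seq_len rows of the causal mask ...
    mask = causal_mask[None, :seq_len]
    # ... then keeps only the last row of the encoder mask, which is all a
    # subsequent single-token decode step needs.
    encoder_mask = encoder_mask[:, -1:]
    assert mask.shape == (1, seq_len, total_len)
    assert encoder_mask.shape == (1, 1, enc_len)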
@@ -404,7 +410,9 @@ def decode_one_token(
            assert batch is not None, "Flamingo requires batch"
            mask = batch["causal_mask"][None, input_pos.item(), None, :]
            encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
-             logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
+             logits = model(
+                 x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos
+             )[:, -1:]
        else:
            logits = model(x, input_pos)
        # print(f"x: {x},\n input_pos: {input_pos}\n")
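
The single-step variant indexes one row of the same causal mask: at decode position t, the lone query token may attend over every cached position. A self-contained sketch of the resulting shape (sizes assumed, as above):

    import torch

    total_len = 2048  # assumed cache length
    causal_mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    t = 32  # illustrative current input_pos
    step_mask = causal_mask[None, t, None, :]  # as in decode_one_token above
    assert step_mask.shape == (1, 1, 1, total_len)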
@@ -492,7 +500,6 @@ def decode_n_tokens(
                    next_prob.clone() if next_prob is not None else None
                )

-
    def model_forward(self, model, x, input_pos):
        return model(x, input_pos)

@@ -605,7 +612,12 @@ def generate(
                or self.model.config.model_type == ModelType.Flamingo
            ):
                # 6404 is one-gpu affordable max_seq_length for single image input
-                 model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
+                 model.setup_caches(
+                     batch_size=1,
+                     dtype=self.dtype,
+                     encoder_max_seq_len=6404,
+                     decoder_max_seq_len=T_new,
+                 )
            else:
                model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
            if is_speculative and draft_model is not model:
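
The constant 6404 is presumably the per-image encoder sequence length under torchtune's Llama 3.2 vision defaults; the tile/patch numbers below are assumptions based on those defaults, not stated in this diff:

    # Assumed Llama 3.2 vision geometry: each 560x560 tile is split into
    # 14x14 patches (a 40x40 grid) plus one CLS token, and a single image
    # spans at most four tiles.
    tile_size, patch_size, max_num_tiles = 560, 14, 4
    tokens_per_tile = (tile_size // patch_size) ** 2 + 1   # 1601
    encoder_max_seq_len = tokens_per_tile * max_num_tiles  # 4 * 1601 = 6404
    assert encoder_max_seq_len == 6404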
@@ -731,9 +743,9 @@ def _gen_model_input(
        max_new_tokens: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        """
-         Convert prompt and image prompts into consumable model input args.
+         Convert prompt and image prompts into consumable model input args.

-         When prompt is a list, the anticipated format is OpenAI API Inspired:
+         When prompt is a list, the anticipated format is OpenAI API Inspired:
            [ ..., {"role": message["role"], "content": message["content"]}, ...]

        Args:
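
Per that docstring, a list-valued prompt follows the OpenAI-style message schema. A minimal example of such an input (contents illustrative):

    prompt = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is shown in this image?"},
    ]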
@@ -826,15 +838,18 @@ def _gen_model_input(
        logging.debug(encoded)
        return encoded, batch

-
    def chat(
        self,
        generator_args: GeneratorArgs,
    ):
        if generator_args.chat_mode:
            print("Starting Interactive Chat")
-
-         encoded, batch = self._gen_model_input(generator_args.prompt, generator_args.image_prompts, generator_args.max_new_tokens)
+
+         encoded, batch = self._gen_model_input(
+             generator_args.prompt,
+             generator_args.image_prompts,
+             generator_args.max_new_tokens,
+         )

        model_size = sum(
            [
@@ -900,7 +915,7 @@ def chat(
                    if text_transformer_args is not None
                    else 2048
                ),
-                 max_seq_length
+                 max_seq_length,
            )

        max_seq_length = (