@@ -353,27 +353,34 @@ def prefill(
         width = x.size(1)
         assert input_pos.size(0) == width
 
-        if batch is not None:
+        if self.model.config.model_type == ModelType.Flamingo:
+            assert batch is not None, "Flamingo requires batch"
+
             # TODO: Verify sequential prefill works with multimodal models
-            tokens = batch["tokens"]
+            is_multimodal = True
             if 'encoder_input' in batch:
                 encoder_input = batch['encoder_input']
+                encoder_mask = batch["encoder_mask"]
+                is_multimodal = True
             else:
                 encoder_input = None
+                encoder_mask = None
+                is_multimodal = False
 
-            seq_len = tokens.size(1)
+            seq_len = x.size(1)
             mask = batch["causal_mask"][None, :seq_len]
-            encoder_mask = batch["encoder_mask"]
             input_pos = input_pos.view(1, -1)
-            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+            logits = model(tokens=x, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+
+            if is_multimodal:
+                batch["encoder_mask"] = batch["encoder_mask"][:, -1:]
+
             return tune_sample(logits, temperature=0, top_k=500)
         elif sequential_prefill:
             for i in range(width):
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
                 # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
-                logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
-        elif self.model.config.model_type == ModelType.Flamingo:
-            assert False, "Flamingo requires batch"
+                logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
         else:
             # input_pos: [B, S]
             logits = model(x, input_pos)
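
For context on the mask bookkeeping above, here is a minimal standalone sketch, not part of the diff; the sequence lengths and the two-token encoder mask are made-up illustrations, not values from the model:

import torch

seq_len, max_seq_len = 4, 8
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

# Prefill takes the first seq_len rows and adds a batch dimension,
# mirroring batch["causal_mask"][None, :seq_len]:
mask = causal_mask[None, :seq_len]                 # shape [1, 4, 8]

# Hypothetical encoder mask over two image tokens:
encoder_mask = torch.ones(1, seq_len, 2, dtype=torch.bool)

# After prefill only the last position still needs an encoder-mask row,
# mirroring batch["encoder_mask"] = batch["encoder_mask"][:, -1:]:
encoder_mask = encoder_mask[:, -1:]                # shape [1, 1, 2]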
@@ -397,7 +404,7 @@ def decode_one_token(
         if model.config.model_type == ModelType.Flamingo:
             assert batch is not None, "Flamingo requires batch"
             mask = batch["causal_mask"][None, input_pos.item(), None, :]
-            encoder_mask = batch["encoder_mask"][:, -1:]
+            encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
             logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
         else:
             logits = model(x, input_pos)
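
The decode-side counterpart, again as a standalone sketch with made-up shapes: each step selects the single causal-mask row for the current position, and a text-only batch simply lacks the "encoder_mask" key, which the new fallback maps to None:

import torch

max_seq_len = 8
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.tensor([5])

# One causal-mask row, shaped [1, 1, max_seq_len], mirroring
# batch["causal_mask"][None, input_pos.item(), None, :]:
mask = causal_mask[None, input_pos.item(), None, :]

# A text-only batch carries no "encoder_mask" key, so the fallback yields None:
batch = {"causal_mask": causal_mask}
encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
assert encoder_mask is None and mask.shape == (1, 1, max_seq_len)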
@@ -733,41 +740,56 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        if generator_args.image_prompts is not None:
-            print("Image prompts", generator_args.image_prompts)
+        if self.model.config.model_type == ModelType.Flamingo:
+
+            is_multimodal = generator_args.image_prompts is not None
+            content = [{"type": "text", "content": generator_args.prompt}]
+
+            if is_multimodal:
+                print("Image prompts", generator_args.image_prompts)
+
+                # Support for just the first image prompt for now
+                images = [Image.open(generator_args.image_prompts[0])]
+                content = [{"type": "image", "content": images[0]}] + content
 
-            # Support for just the first image prompt for now
-            images = [Image.open(generator_args.image_prompts[0])]
             messages = [
                 Message(
                     role="user",
-                    content=[
-                        {"type": "image", "content": images[0]},
-                        {"type": "text", "content": generator_args.prompt},
-                    ],
+                    content=content,
                     eot=True,
                 ),
                 Message(role="assistant", content=""),
             ]
 
             transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
-            with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
+            device = torch.device(device=self.builder_args.device)
+
+            with device, set_default_dtype(self.dtype):
                 data = transform({"messages": messages}, inference=True)
-                batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
-                # set_default_dtype can not handle the dtype of the image tensor inside the batch; need to manually cast it
-                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-                seq_len = len(data["tokens"])
+
+                if is_multimodal:
+                    batch = padded_collate_tiled_images_and_mask(
+                        [data], pad_direction="left", pad_max_images=1
+                    )
+                    encoded = batch.pop("tokens").to(device).view(-1)
+                    seq_len = encoded.size(0)
+                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+                else:
+                    encoded = torch.tensor(
+                        data["tokens"], device=device
+                    ).view(-1)
+                    seq_len = encoded.size(0)
+                    batch = {}
+
                 total_response_length = seq_len + generator_args.max_new_tokens
                 batch["causal_mask"] = torch.tril(
                     torch.ones(
                         size=(total_response_length, total_response_length),
                         dtype=torch.bool,
                     )
                 )
-                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                encoded = batch["tokens"].view(-1)
-
         else:
             encoded = self.encode_tokens(
                 generator_args.prompt, bos=True, device=self.builder_args.device
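
To make the new text-only branch concrete, a standalone sketch of the state it sets up before generation; the token ids and max_new_tokens here are invented, not real tokenizer output:

import torch

tokens, max_new_tokens = [1, 15, 27, 2], 16        # made-up values
encoded = torch.tensor(tokens).view(-1)            # flat 1-D token tensor
seq_len = encoded.size(0)

# The batch starts empty and gains only the boolean lower-triangular
# causal mask covering prompt plus response:
batch = {}
total_response_length = seq_len + max_new_tokens
batch["causal_mask"] = torch.tril(
    torch.ones(
        size=(total_response_length, total_response_length),
        dtype=torch.bool,
    )
)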