@@ -1300,14 +1300,19 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
         )
 
         # Run the transformer, image_positional_embedding are consumed
-        point_embedding, image_embeddings, attentions = self.transformer(
+        torch._check(point_embeddings.shape[0] != 0)
+        torch._check(point_embeddings.shape[1] != 0)
+        torch._check(point_embeddings.shape[2] != 0)
+        torch._check(point_embeddings.shape[3] != 0)
+        embeddings_attentions = self.transformer(
             point_embeddings=point_embeddings,
             image_embeddings=image_embeddings,
             image_positional_embeddings=image_positional_embeddings,
             attention_similarity=attention_similarity,
             target_embedding=target_embedding,
             output_attentions=output_attentions,
         )
+        point_embedding, image_embeddings = embeddings_attentions[:2]
         iou_token_out = torch.select(point_embedding, dim=2, index=0)
         mask_tokens_out = torch.narrow(
             point_embedding, dim=2, start=1, length=self.num_mask_tokens
@@ -1349,9 +1354,12 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
 
         outputs = (masks, iou_pred)
 
-        if output_attentions:
-            outputs = outputs + (attentions,)  # noqa: RUF005
+        if len(embeddings_attentions) == 2:
+            # transformers==4.54
+            return outputs
+
+        if output_attentions and len(embeddings_attentions) > 2:
+            outputs = outputs + (embeddings_attentions[2],)  # noqa: RUF005
         else:
             outputs = outputs + (None,)  # noqa: RUF005
-
         return outputs
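For context on the first hunk: `torch._check` records a runtime assertion that `torch.export` can rely on while tracing, so asserting that every dimension of `point_embeddings` is non-zero keeps the exporter from guarding or specializing on empty tensors. The snippet below is only a minimal sketch of that pattern, not the SAM code itself; the module `NonEmpty`, the dims `batch`/`seq`, and the toy matmul are made up for illustration, assuming a recent PyTorch where `torch.export` and `torch.export.Dim` are available.

```python
import torch


class NonEmpty(torch.nn.Module):
    def forward(self, x):
        # Assert (and let the exporter assume) that no dimension is zero,
        # mirroring the torch._check calls added before self.transformer(...).
        torch._check(x.shape[0] != 0)
        torch._check(x.shape[1] != 0)
        return x @ x.transpose(0, 1)


batch = torch.export.Dim("batch")
seq = torch.export.Dim("seq")
ep = torch.export.export(
    NonEmpty(),
    (torch.randn(3, 5),),
    dynamic_shapes={"x": {0: batch, 1: seq}},
)
print(ep.graph)  # the asserts appear as runtime checks on the symbolic dims
```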
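The second hunk copes with the transformer returning a variable-length tuple: two elements on transformers==4.54, three or more (with the attentions last) on other releases. Below is a small self-contained sketch of that unpacking, using a hypothetical `split_transformer_outputs` helper and dummy tensors `pe`/`ie` standing in for the real transformer outputs.

```python
from typing import Optional, Tuple

import torch


def split_transformer_outputs(
    outputs: Tuple,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple]]:
    # Hypothetical helper mirroring the patched forward(): the first two
    # elements are always (point_embedding, image_embeddings); attentions
    # are only present as a third element on releases that return them.
    point_embedding, image_embeddings = outputs[:2]
    attentions = outputs[2] if len(outputs) > 2 else None
    return point_embedding, image_embeddings, attentions


pe, ie = torch.randn(1, 1, 5, 8), torch.randn(1, 1, 64, 8)
print(split_transformer_outputs((pe, ie))[2])         # None (4.54 layout)
print(split_transformer_outputs((pe, ie, (pe,)))[2])  # attentions passthrough
```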