@@ -1184,6 +1184,7 @@ def forward(
             return attn_output, attn_weights, past_key_value
         return attn_output, attn_weights
 
+
 class patched_SamMaskDecoder(torch.nn.Module):
     _PATCHES_ = ["forward"]
     _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder
@@ -1223,10 +1224,11 @@ def forward(
         output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
 
         # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
-        # torch.any is needed to avoid data-dependent control flow
+        # torch.any is needed to avoid data-dependent control flow
         # with sparse_prompt_embeddings.sum().item() != 0
         def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
             return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+
         def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
             return output_tokens.clone()
 
@@ -1242,7 +1244,9 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
         # Expand per-image data in batch direction to be per-point
         image_embeddings = image_embeddings + dense_prompt_embeddings
         image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
-        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
+        image_positional_embeddings = image_positional_embeddings.repeat_interleave(
+            point_batch_size, 0
+        )
 
         # Run the transformer, image_positional_embedding are consumed
         point_embedding, image_embeddings, attentions = self.transformer(
@@ -1272,8 +1276,12 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
         hyper_in = torch.stack(hyper_in_list, dim=2)
 
         _, num_channels, height, width = upscaled_embedding.shape
-        upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
-        masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
+        upscaled_embedding = upscaled_embedding.reshape(
+            batch_size, point_batch_size, num_channels, height * width
+        )
+        masks = (hyper_in @ upscaled_embedding).reshape(
+            batch_size, point_batch_size, -1, height, width
+        )
 
         # Generate mask quality predictions
         iou_pred = self.iou_prediction_head(iou_token_out)
@@ -1289,8 +1297,8 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
         outputs = (masks, iou_pred)
 
         if output_attentions:
-            outputs = outputs + (attentions,)
+            outputs = outputs + (attentions,)  # noqa: RUF005
         else:
-            outputs = outputs + (None,)
+            outputs = outputs + (None,)  # noqa: RUF005
 
-        return outputs
+        return outputs
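
For readers unfamiliar with the pattern described by the comment in the second hunk: the upstream SamMaskDecoder guards the concatenation with `if sparse_prompt_embeddings.sum().item() != 0:`, which is data-dependent Python control flow and cannot be traced by torch.export. The patch moves both outcomes into small branch functions and selects between them with torch.cond, using torch.any(...) so the predicate stays a tensor. Below is a minimal, self-contained sketch of that pattern, not the patched decoder itself: the module name is invented, the non-empty branch uses a shape-preserving addition instead of the real concatenation so both branches return identically shaped tensors, and it assumes a recent PyTorch (2.4+) where torch.cond runs eagerly and under torch.export.

import torch


# Hypothetical illustration of the torch.cond pattern used in the patch above
# (names and the addition in the non-empty branch are invented for the sketch).
class FoldPromptsIfAny(torch.nn.Module):
    def forward(self, output_tokens, sparse_prompt_embeddings):
        def is_not_empty(output_tokens, sparse_prompt_embeddings):
            # prompts present: fold them into the tokens
            # (shape-preserving stand-in for the real torch.cat along dim=2)
            return output_tokens + sparse_prompt_embeddings.sum()

        def is_empty(output_tokens, sparse_prompt_embeddings):
            # no prompts: pass the tokens through unchanged
            return output_tokens.clone()

        # torch.any keeps the predicate a tensor, so no .item() call
        # (data-dependent Python branching) is needed and both branches
        # can be traced by torch.export.
        return torch.cond(
            torch.any(sparse_prompt_embeddings != 0),
            is_not_empty,
            is_empty,
            (output_tokens, sparse_prompt_embeddings),
        )


tokens = torch.randn(2, 3, 5, 8)
prompts = torch.zeros(2, 3, 2, 8)
module = FoldPromptsIfAny()
print(module(tokens, prompts).shape)  # torch.Size([2, 3, 5, 8]), empty branch taken
ep = torch.export.export(module, (tokens, prompts))  # both branches are captured

The key point mirrored from the diff is that the branch bodies take all tensors they need as explicit arguments, and the predicate never leaves tensor-land, which is what lets the exported graph keep both paths instead of baking in whichever branch the example inputs happened to hit.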