@@ -1183,3 +1183,114 @@ def forward(
         if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
             return attn_output, attn_weights, past_key_value
         return attn_output, attn_weights
+
+class patched_SamMaskDecoder(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_positional_embeddings: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        output_attentions: Optional[bool] = None,
+        attention_similarity: Optional[torch.Tensor] = None,
+        target_embedding: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[tuple[torch.Tensor, ...]]]:
1202+ """
1203+ Predict masks given image and prompt embeddings.
1204+
1205+ Args:
1206+ image_embeddings (`torch.Tensor`):
1207+ the embeddings from the image encoder
1208+ image_positional_embedding (`torch.Tensor`):
1209+ positional encoding with the shape of image_embeddings
1210+ sparse_prompt_embeddings (`torch.Tensor`):
1211+ The embeddings of the points and boxes
1212+ dense_prompt_embeddings (`torch.Tensor`):
1213+ the embeddings of the mask inputs
1214+ multimask_output (bool):
1215+ Whether to return multiple masks or a single mask.
1216+ output_attentions (bool, *optional*):
1217+ Whether or not to return the attentions tensors of all attention layers.
1218+ """
+        batch_size, num_channels, height, width = image_embeddings.shape
+        point_batch_size = sparse_prompt_embeddings.shape[1]
+        # Concatenate output tokens
+        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+        # torch.cond replaces the original if-else on
+        # sparse_prompt_embeddings.sum().item() != 0, data-dependent control
+        # flow that torch.export cannot trace; torch.any keeps the predicate a
+        # tensor so both branches stay in the graph (a standalone sketch of
+        # this rewrite follows the diff).
+        def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
+            return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+
+        def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
+            return output_tokens.clone()
+
+        tokens = torch.cond(
+            torch.any(sparse_prompt_embeddings != 0),
+            sparse_prompt_embeddings_is_not_empty,
+            sparse_prompt_embeddings_is_empty,
+            [output_tokens, sparse_prompt_embeddings],
+        )
+
+        point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+        # Expand per-image data in batch direction to be per-point
+        image_embeddings = image_embeddings + dense_prompt_embeddings
+        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
+        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
+
+        # Run the transformer; image_positional_embeddings are consumed
+        point_embedding, image_embeddings, attentions = self.transformer(
+            point_embeddings=point_embeddings,
+            image_embeddings=image_embeddings,
+            image_positional_embeddings=image_positional_embeddings,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
+            output_attentions=output_attentions,
+        )
+        iou_token_out = point_embedding[:, :, 0, :]
+        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        image_embeddings = image_embeddings.transpose(2, 3).reshape(
+            batch_size * point_batch_size, num_channels, height, width
+        )
+
+        upscaled_embedding = self.upscale_conv1(image_embeddings)
+        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
+
+        hyper_in_list = []
+        for i in range(self.num_mask_tokens):
+            current_mlp = self.output_hypernetworks_mlps[i]
+            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+        hyper_in = torch.stack(hyper_in_list, dim=2)
+
+        _, num_channels, height, width = upscaled_embedding.shape
+        upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
+        masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            mask_slice = slice(1, None)
+        else:
+            mask_slice = slice(0, 1)
+        masks = masks[:, :, mask_slice, :, :]
+        iou_pred = iou_pred[:, :, mask_slice]
+
+        outputs = (masks, iou_pred)
+
+        if output_attentions:
+            outputs = outputs + (attentions,)
+        else:
+            outputs = outputs + (None,)
+
+        return outputs
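
The torch.cond call above is the core of the patch: it replaces the upstream data-dependent branch on sparse_prompt_embeddings.sum().item(), which torch.export cannot trace. A minimal standalone sketch of the same rewrite (illustrative only; the Toy module, shapes, and values are assumptions, not part of this PR):

import torch

class Toy(torch.nn.Module):
    def forward(self, a, b):
        # Export-hostile original pattern:
        #     if b.sum().item() != 0:
        #         return a + b
        #     return a.clone()
        def add(a, b):
            return a + b

        def passthrough(a, b):
            return a.clone()

        # A tensor predicate plus branch functions keeps both paths in the graph.
        return torch.cond(torch.any(b != 0), add, passthrough, [a, b])

ep = torch.export.export(Toy(), (torch.ones(2, 3), torch.zeros(2, 3)))
print(ep.graph)  # both branches appear as cond subgraphs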
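
The _PATCHES_ and _PATCHED_CLASS_ attributes indicate that the surrounding patching machinery rebinds the listed methods onto the target class. A minimal sketch of how such a patcher could apply and undo this class (apply_patch is a hypothetical helper name, not an API from this repository):

def apply_patch(patch_cls):
    # Hypothetical helper: rebind each method named in _PATCHES_ onto the
    # class referenced by _PATCHED_CLASS_, returning the originals so the
    # patch can be reverted after export.
    target = patch_cls._PATCHED_CLASS_
    originals = {name: getattr(target, name) for name in patch_cls._PATCHES_}
    for name in patch_cls._PATCHES_:
        setattr(target, name, getattr(patch_cls, name))
    return originals

# Usage sketch:
# originals = apply_patch(patched_SamMaskDecoder)
# ... export the SAM model with torch.export ...
# for name, fn in originals.items():
#     setattr(patched_SamMaskDecoder._PATCHED_CLASS_, name, fn)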