
Commit 8eb4d1c

add patch
1 parent 88b4f89 commit 8eb4d1c

File tree

1 file changed: +111 −0 lines changed

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 111 additions & 0 deletions
@@ -1183,3 +1183,114 @@ def forward(
        if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
            return attn_output, attn_weights, past_key_value
        return attn_output, attn_weights


class patched_SamMaskDecoder(torch.nn.Module):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        output_attentions: Optional[bool] = None,
        attention_similarity: Optional[torch.Tensor] = None,
        target_embedding: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor`):
                the embeddings from the image encoder
            image_positional_embeddings (`torch.Tensor`):
                positional encoding with the shape of image_embeddings
            sparse_prompt_embeddings (`torch.Tensor`):
                the embeddings of the points and boxes
            dense_prompt_embeddings (`torch.Tensor`):
                the embeddings of the mask inputs
            multimask_output (bool):
                Whether to return multiple masks or a single mask.
            output_attentions (bool, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1]
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

        # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings;
        # torch.any is needed to avoid the data-dependent control flow of
        # sparse_prompt_embeddings.sum().item() != 0.
        def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
            return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)

        def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
            return output_tokens.clone()

        tokens = torch.cond(
            torch.any(sparse_prompt_embeddings != 0),
            sparse_prompt_embeddings_is_not_empty,
            sparse_prompt_embeddings_is_empty,
            [output_tokens, sparse_prompt_embeddings],
        )

        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        # Expand per-image data in batch direction to be per-point
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)

        # Run the transformer, image_positional_embeddings are consumed
        point_embedding, image_embeddings, attentions = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
        )
        iou_token_out = point_embedding[:, :, 0, :]
        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        image_embeddings = image_embeddings.transpose(2, 3).reshape(
            batch_size * point_batch_size, num_channels, height, width
        )

        upscaled_embedding = self.upscale_conv1(image_embeddings)
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))

        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = torch.stack(hyper_in_list, dim=2)

        _, num_channels, height, width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
        masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, mask_slice, :, :]
        iou_pred = iou_pred[:, :, mask_slice]

        outputs = (masks, iou_pred)

        if output_attentions:
            outputs = outputs + (attentions,)
        else:
            outputs = outputs + (None,)

        return outputs
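
For readers unfamiliar with the torch.cond rewrite used in this patch, here is a minimal, self-contained sketch (a hypothetical toy module, not part of this repository or of transformers) of the same idea: keeping the predicate as a tensor via torch.any and routing it through torch.cond lets torch.export trace both branches, whereas an if statement on a value obtained with .item() would typically fail export with a data-dependent control-flow error.

import torch


class ToyCond(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both branches take the same operands and return tensors with
        # matching metadata (shape and dtype).
        def branch_nonzero(x):
            return x * 2.0

        def branch_zero(x):
            return x.clone()

        # The predicate stays a tensor (torch.any), so export does not have to
        # guard on a data-dependent Python bool.
        return torch.cond(torch.any(x != 0), branch_nonzero, branch_zero, [x])


ep = torch.export.export(ToyCond(), (torch.randn(2, 3),))
print(ep.graph_module.code)  # both branches show up as subgraphs of a cond node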
