Skip to content

Commit 34f5ed6

Browse files
committed
link
1 parent 8eb4d1c commit 34f5ed6

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

onnx_diagnostic/tasks/mask_generation.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Any, Callable, Dict, Optional, Tuple
22
import torch
3-
from ..helpers.cache_helper import make_dynamic_cache
4-
from ..helpers.config_helper import update_config, check_hasattr, _pick
3+
from ..helpers.config_helper import update_config, check_hasattr
54

65
__TASK__ = "mask-generation"
76

@@ -31,7 +30,7 @@ def get_inputs(
3130
):
3231
"""
3332
Generates input for task ``mask-generation``.
34-
33+
3534
:param model: model to get the missing information
3635
:param config: configuration used to generate the model
3736
:param batch_size: batch size
@@ -46,10 +45,10 @@ def get_inputs(
4645
assert (
4746
"cls_cache" not in kwargs
4847
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
49-
5048

51-
# TODO(anyone): input_masks is weirdly failing all the time with mismatched channels with Conv
52-
# or embedding_size. I guess maybe the model is too implicit on the input_masks shape.
49+
# TODO(anyone): input_masks is weirdly failing all the time with mismatched channels
50+
# with Conv or embedding_size. I guess maybe the model is too implicit on the
51+
# input_masks shape.
5352

5453
shapes = {
5554
"pixel_values": {0: "batch", 2: "height", 3: "width"}, # 1: num_channels is static
@@ -64,9 +63,7 @@ def get_inputs(
6463
input_points=torch.randn(
6564
(batch_size, 1, 10, 2), dtype=torch.float32
6665
), # 10 points per image
67-
input_boxes=torch.randn(
68-
(batch_size, 1, 4), dtype=torch.float32
69-
), # 1 box per image
66+
input_boxes=torch.randn((batch_size, 1, 4), dtype=torch.float32), # 1 box per image
7067
# input_masks=torch.randn(
7168
# (batch_size, 1, height, width), dtype=torch.float32
7269
# ), # mask for the image

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,7 @@ def forward(
11841184
return attn_output, attn_weights, past_key_value
11851185
return attn_output, attn_weights
11861186

1187+
11871188
class patched_SamMaskDecoder(torch.nn.Module):
11881189
_PATCHES_ = ["forward"]
11891190
_PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder
@@ -1223,10 +1224,11 @@ def forward(
12231224
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
12241225

12251226
# torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
1226-
# torch.any is needed to avoid data-dependent control flow
1227+
# torch.any is needed to avoid data-dependent control flow
12271228
# with sparse_prompt_embeddings.sum().item() != 0
12281229
def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
12291230
return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
1231+
12301232
def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
12311233
return output_tokens.clone()
12321234

@@ -1242,7 +1244,9 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
12421244
# Expand per-image data in batch direction to be per-point
12431245
image_embeddings = image_embeddings + dense_prompt_embeddings
12441246
image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
1245-
image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
1247+
image_positional_embeddings = image_positional_embeddings.repeat_interleave(
1248+
point_batch_size, 0
1249+
)
12461250

12471251
# Run the transformer; image_positional_embeddings are consumed
12481252
point_embedding, image_embeddings, attentions = self.transformer(
@@ -1272,8 +1276,12 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
12721276
hyper_in = torch.stack(hyper_in_list, dim=2)
12731277

12741278
_, num_channels, height, width = upscaled_embedding.shape
1275-
upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
1276-
masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
1279+
upscaled_embedding = upscaled_embedding.reshape(
1280+
batch_size, point_batch_size, num_channels, height * width
1281+
)
1282+
masks = (hyper_in @ upscaled_embedding).reshape(
1283+
batch_size, point_batch_size, -1, height, width
1284+
)
12771285

12781286
# Generate mask quality predictions
12791287
iou_pred = self.iou_prediction_head(iou_token_out)
@@ -1289,8 +1297,8 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
12891297
outputs = (masks, iou_pred)
12901298

12911299
if output_attentions:
1292-
outputs = outputs + (attentions,)
1300+
outputs = outputs + (attentions,) # noqa: RUF005
12931301
else:
1294-
outputs = outputs + (None,)
1302+
outputs = outputs + (None,) # noqa: RUF005
12951303

1296-
return outputs
1304+
return outputs

0 commit comments

Comments
 (0)