33 changes: 33 additions & 0 deletions _unittests/ut_tasks/test_tasks_mask_generation.py
@@ -0,0 +1,33 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
hide_stdout,
requires_transformers,
requires_torch,
)
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTasksMaskGeneration(ExtTestCase):
@hide_stdout()
@requires_transformers("4.53")
@requires_torch("2.7.99")
def test_mask_generation(self):
mid = "fxmarty/sam-vit-tiny-random"
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
self.assertEqual(data["task"], "mask-generation")
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**torch_deepcopy(inputs))
model(**data["inputs2"])
with torch_export_patches(patch_transformers=True, verbose=10):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 2 additions & 0 deletions onnx_diagnostic/tasks/__init__.py
@@ -14,6 +14,7 @@
text_to_image,
text2text_generation,
zero_shot_image_classification,
mask_generation,
)

__TASKS__ = [
@@ -31,6 +32,7 @@
text_to_image,
text2text_generation,
zero_shot_image_classification,
mask_generation,
]


139 changes: 139 additions & 0 deletions onnx_diagnostic/tasks/mask_generation.py
@@ -0,0 +1,139 @@
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..helpers.config_helper import update_config, check_hasattr

__TASK__ = "mask-generation"


def reduce_model_config(config: Any) -> Dict[str, Any]:
"""Reduces a model size."""
kwargs: Dict[str, Any] = {}
if hasattr(config, "num_hidden_layers"):
config.num_hidden_layers = min(config.num_hidden_layers, 2)
if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
update_config(config, kwargs)
return kwargs


def get_inputs(
model: torch.nn.Module,
config: Optional[Any],
batch_size: int,
width: int,
height: int,
num_channels: int,
output_channels: int,
window_size: int,
add_second_input: bool = True,
**kwargs, # unused
):
"""
Generates input for task ``mask-generation``.

:param model: model to get the missing information
:param config: configuration used to generate the model
:param batch_size: batch size
:param width: width of the image
:param height: height of the image
:param num_channels: number of channels in the image
:param output_channels: number of output channels
:param window_size: size of the window for the vision model
:param add_second_input: if True, also returns a second set of inputs (key ``inputs2``) built with ``batch_size + 1``
:return: dictionary with inputs and dynamic shapes

"""
assert (
"cls_cache" not in kwargs
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."

# TODO(anyone): input_masks keeps failing with a channel mismatch
# (Conv channels vs. embedding_size); the model may be too implicit
# about the expected input_masks shape.

# TODO(titaiwang): the modeling code requires the height and width of the inputs
# to match config.vision_config.image_size. Does that make sense?

shapes = {
"pixel_values": {0: "batch"}, # 1: num_channels is static
"input_points": {0: "batch", 1: "point_batch_size", 2: "nb_points_per_image"},
"input_boxes": {0: "batch", 1: "point_batch_size"},
# "input_masks": {0: "batch", 2: "height", 3: "width"},
}
inputs = dict(
pixel_values=torch.randn(
(batch_size, num_channels, height, width), dtype=torch.float32
).clamp(-1, 1),
input_points=torch.randn(
(batch_size, 2, 10, 2), dtype=torch.float32
), # (batch, point_batch_size, nb_points_per_image, 2 coordinates)
input_boxes=torch.randn((batch_size, 2, 4), dtype=torch.float32), # (batch, point_batch_size, 4 box coordinates)
# input_masks=torch.randn(
# (batch_size, 1, height, width), dtype=torch.float32
# ), # mask for the image
)

res = dict(inputs=inputs, dynamic_shapes=shapes)
if add_second_input:
assert (
add_second_input > 0
), f"Not implemented for add_second_input={add_second_input}."
res["inputs2"] = get_inputs(
model=model,
config=config,
batch_size=batch_size + 1,
width=width,
height=height,
num_channels=num_channels,
output_channels=output_channels,
window_size=window_size,
add_second_input=False,
**kwargs,
)["inputs"]
return res


def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
"""
Inputs kwargs.

If the configuration is None, the function selects typical dimensions.
"""
if config is not None:
# the model generates masks as outputs
if hasattr(config, "mask_decoder_config"):
check_hasattr(
config.mask_decoder_config,
"hidden_size",
"iou_head_hidden_dim",
"iou_head_depth",
"num_hidden_layers",
"num_multimask_outputs",
)
if hasattr(config, "prompt_encoder_config"):
check_hasattr(
config.prompt_encoder_config,
"hidden_size",
"image_embedding_size",
"image_size",
"mask_input_channels",
)
if hasattr(config, "vision_config"):
check_hasattr(
config.vision_config,
"image_size",
"hidden_size",
"intermediate_size",
"num_hidden_layers",
"output_channels",
"num_channels",
"window_size",
)
kwargs = dict(
batch_size=2,
width=1024 if config is None else config.vision_config.image_size,
height=1024 if config is None else config.vision_config.image_size,
num_channels=3 if config is None else config.vision_config.num_channels,
output_channels=256 if config is None else config.vision_config.output_channels,
window_size=14 if config is None else config.vision_config.window_size,
)
return kwargs, get_inputs
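
Note (editor's sketch, not part of the diff): the two helpers above follow the convention of the other task modules — ``random_input_kwargs`` picks typical dimensions (from the config when one is available) and returns the ``get_inputs`` callable that builds the tensors and their dynamic shapes. A minimal usage sketch, assuming the task is driven the same way as the existing ones:

# Sketch only: assumes onnx_diagnostic.tasks.mask_generation is used like the other task modules.
from onnx_diagnostic.tasks import mask_generation

kwargs, fct = mask_generation.random_input_kwargs(config=None)  # typical dimensions (1024x1024, 3 channels, ...)
data = fct(model=None, config=None, **kwargs)  # model/config are only passed through, not read here
print(sorted(data["inputs"]))   # ['input_boxes', 'input_points', 'pixel_values']
print(data["dynamic_shapes"])   # batch / point_batch_size / nb_points_per_image axes
print("inputs2" in data)        # True: a second set built with batch_size + 1
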
121 changes: 121 additions & 0 deletions onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1183,3 +1183,124 @@ def forward(
if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights


class patched_SamMaskDecoder(torch.nn.Module):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder

def forward(
self,
image_embeddings: torch.Tensor,
image_positional_embeddings: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
output_attentions: Optional[bool] = None,
attention_similarity: Optional[torch.Tensor] = None,
target_embedding: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.

Args:
image_embeddings (`torch.Tensor`):
the embeddings from the image encoder
image_positional_embeddings (`torch.Tensor`):
positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (`torch.Tensor`):
The embeddings of the points and boxes
dense_prompt_embeddings (`torch.Tensor`):
the embeddings of the mask inputs
multimask_output (bool):
Whether to return multiple masks or a single mask.
output_attentions (bool, *optional*):
Whether or not to return the attentions tensors of all attention layers.
"""
batch_size, num_channels, height, width = image_embeddings.shape
point_batch_size = sparse_prompt_embeddings.shape[1]
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

# torch.cond replaces the original if/else on
# sparse_prompt_embeddings.sum().item() != 0, which is data-dependent
# control flow; torch.any keeps the predicate as a tensor so torch.export
# can trace both branches.
def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)

def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
return output_tokens.clone()

tokens = torch.cond(
torch.any(sparse_prompt_embeddings != 0),
sparse_prompt_embeddings_is_not_empty,
sparse_prompt_embeddings_is_empty,
[output_tokens, sparse_prompt_embeddings],
)

point_embeddings = tokens.to(self.iou_token.weight.dtype)

# Expand per-image data in batch direction to be per-point
image_embeddings = image_embeddings + dense_prompt_embeddings
image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
image_positional_embeddings = image_positional_embeddings.repeat_interleave(
point_batch_size, 0
)

# Run the transformer, image_positional_embeddings are consumed
point_embedding, image_embeddings, attentions = self.transformer(
point_embeddings=point_embeddings,
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
)
iou_token_out = torch.select(point_embedding, dim=2, index=0)
mask_tokens_out = torch.narrow(
point_embedding, dim=2, start=1, length=self.num_mask_tokens
)

# Upscale mask embeddings and predict masks using the mask tokens
image_embeddings = image_embeddings.transpose(2, 3).reshape(
batch_size * point_batch_size, num_channels, height, width
)

upscaled_embedding = self.upscale_conv1(image_embeddings)
upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))

hyper_in_list = []
for i in range(self.num_mask_tokens):
current_mlp = self.output_hypernetworks_mlps[i]
hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
hyper_in = torch.stack(hyper_in_list, dim=2)

_, num_channels, height, width = upscaled_embedding.shape
upscaled_embedding = upscaled_embedding.reshape(
batch_size, point_batch_size, num_channels, height * width
)
masks = (hyper_in @ upscaled_embedding).reshape(
batch_size, point_batch_size, -1, height, width
)

# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)

# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]

outputs = (masks, iou_pred)

if output_attentions:
outputs = outputs + (attentions,) # noqa: RUF005
else:
outputs = outputs + (None,) # noqa: RUF005

return outputs
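
Side note (editor's sketch, not part of the patch): a self-contained illustration of the ``torch.cond`` pattern used in the patched forward — the original data-dependent Python ``if`` (``sparse_prompt_embeddings.sum().item() != 0``) becomes a tensor predicate so ``torch.export`` can trace both branches. All names below are illustrative.

import torch


class CondExample(torch.nn.Module):
    def forward(self, x):
        # A plain `if x.sum().item() != 0:` is data-dependent control flow and
        # fails under torch.export; torch.cond keeps both branches in the graph.
        def not_empty(x):
            return x.cos()

        def empty(x):
            return x.sin()

        return torch.cond(torch.any(x != 0), not_empty, empty, [x])


ep = torch.export.export(CondExample(), (torch.randn(2, 3),))
print(ep.graph_module.code)  # both branches appear as subgraphs of a cond node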