Skip to content

Commit 2250887

Browse files
committed
fix inputs
1 parent f86d081 commit 2250887

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

onnx_diagnostic/tasks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
text_to_image,
1515
text2text_generation,
1616
zero_shot_image_classification,
17+
mask_generation,
1718
)
1819

1920
__TASKS__ = [
@@ -31,6 +32,7 @@
3132
text_to_image,
3233
text2text_generation,
3334
zero_shot_image_classification,
35+
mask_generation,
3436
]
3537

3638

onnx_diagnostic/tasks/mask_generation.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,14 @@ def get_inputs(
4848
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
4949

5050

51+
# TODO(anyone): input_masks is weridly failing all the time with mismatch channels with Conv
52+
# or embedding_size. I guess maybe the model is too implicit on the input_masks shape.
53+
5154
shapes = {
5255
"pixel_values": {0: "batch", 2: "height", 3: "width"}, # 1: num_channels is static
5356
"input_points": {0: "batch", 1: "point_batch_size", 2: "nb_points_per_image"},
5457
"input_boxes": {0: "batch", 1: "point_batch_size"},
55-
"input_masks": {0: "batch", 1: "height", 2: "width"},
58+
# "input_masks": {0: "batch", 2: "height", 3: "width"},
5659
}
5760
inputs = dict(
5861
pixel_values=torch.randn(
@@ -64,10 +67,11 @@ def get_inputs(
6467
input_boxes=torch.randn(
6568
(batch_size, 1, 4), dtype=torch.float32
6669
), # 1 box per image
67-
input_masks=torch.randn(
68-
(batch_size, num_channels, height, width), dtype=torch.float32
69-
), # mask for the image
70+
# input_masks=torch.randn(
71+
# (batch_size, 1, height, width), dtype=torch.float32
72+
# ), # mask for the image
7073
)
74+
7175
res = dict(inputs=inputs, dynamic_shapes=shapes)
7276
if add_second_input:
7377
assert (
@@ -82,6 +86,7 @@ def get_inputs(
8286
num_channels=num_channels,
8387
output_channels=output_channels,
8488
window_size=window_size,
89+
add_second_input=False,
8590
**kwargs,
8691
)["inputs"]
8792
return res
@@ -128,7 +133,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
128133
width=1024 if config is None else config.vision_config.image_size,
129134
height=1024 if config is None else config.vision_config.image_size,
130135
num_channels=3 if config is None else config.vision_config.num_channels,
131-
output_channels=256 if config is None else config.vision_config.num_channels,
136+
output_channels=256 if config is None else config.vision_config.output_channels,
132137
window_size=14 if config is None else config.vision_config.window_size,
133138
)
134139
return kwargs, get_inputs

0 commit comments

Comments
 (0)