
Commit b3cb1f2

fix patch

1 parent 245be95 · commit b3cb1f2

2 files changed: +13 −5 lines changed

_unittests/ut_tasks/test_tasks_mask_generation.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def test_mask_generation(self):
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=10):
+        with torch_export_patches(patch_transformers=True, verbose=1):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
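For reference, the pattern this test exercises is a patched export: the context manager temporarily patches parts of transformers so the model can be traced, then restores them on exit. Below is a minimal stand-alone sketch of the same pattern; the import path for torch_export_patches is an assumption based on the package layout, not confirmed by this diff.

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed import path

def export_patched(model, inputs, dynamic_shapes):
    # Apply the transformers patches only for the duration of the export.
    with torch_export_patches(patch_transformers=True, verbose=1):
        return torch.export.export(
            model, (), kwargs=inputs, dynamic_shapes=dynamic_shapes, strict=False
        )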

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 12 additions & 4 deletions
@@ -1300,14 +1300,19 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
         )
 
         # Run the transformer, image_positional_embedding are consumed
-        point_embedding, image_embeddings, attentions = self.transformer(
+        torch._check(point_embeddings.shape[0] != 0)
+        torch._check(point_embeddings.shape[1] != 0)
+        torch._check(point_embeddings.shape[2] != 0)
+        torch._check(point_embeddings.shape[3] != 0)
+        embeddings_attentions = self.transformer(
             point_embeddings=point_embeddings,
             image_embeddings=image_embeddings,
             image_positional_embeddings=image_positional_embeddings,
             attention_similarity=attention_similarity,
             target_embedding=target_embedding,
             output_attentions=output_attentions,
         )
+        point_embedding, image_embeddings = embeddings_attentions[:2]
         iou_token_out = torch.select(point_embedding, dim=2, index=0)
         mask_tokens_out = torch.narrow(
             point_embedding, dim=2, start=1, length=self.num_mask_tokens
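The added torch._check calls record, for the exporter, the assumption that no dimension of point_embeddings is empty; torch.export keeps them as runtime asserts and can use the refined ranges when reasoning about those symbolic sizes. A minimal, self-contained sketch of the mechanism (TinyModel is illustrative only, not part of this patch):

import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        # Tell the exporter the batch dimension is non-empty; the constraint
        # is kept as a runtime assert in the exported program.
        torch._check(x.shape[0] != 0)
        return x.reshape(x.shape[0], -1).sum(dim=1)

ep = torch.export.export(
    TinyModel(),
    (torch.randn(4, 3),),
    dynamic_shapes=({0: torch.export.Dim("batch")},),
    strict=False,
)
print(ep)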
@@ -1349,9 +1354,12 @@ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
 
         outputs = (masks, iou_pred)
 
-        if output_attentions:
-            outputs = outputs + (attentions,)  # noqa: RUF005
+        if len(embeddings_attentions) == 2:
+            # transformers==4.54
+            return outputs
+
+        if output_attentions and len(embeddings_attentions) > 2:
+            outputs = outputs + (embeddings_attentions[2],)  # noqa: RUF005
         else:
             outputs = outputs + (None,)  # noqa: RUF005
-
         return outputs
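The second hunk makes the return handling tolerant to how many values the transformer yields: on transformers==4.54 the call returns only the two embeddings, while other versions also return attentions. A hypothetical helper (run_transformer is not part of the patch) showing the same positional-unpacking pattern:

def run_transformer(transformer, output_attentions=False, **kwargs):
    # The underlying call may return a 2-tuple (transformers==4.54) or a
    # 3-tuple that additionally carries the attentions.
    result = transformer(output_attentions=output_attentions, **kwargs)
    point_embedding, image_embeddings = result[:2]
    attentions = result[2] if output_attentions and len(result) > 2 else None
    return point_embedding, image_embeddings, attentions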
