Skip to content

Commit 61375dc

Browse files
committed
fix
1 parent 88efcfe commit 61375dc

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ def vector_mask_function(
4545
), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
4646
torch._check(a.shape[0] > 0)
4747

48-
# new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
49-
new_args = [
50-
a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
51-
for a, dims in zip(args, udimensions)
52-
]
48+
new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
49+
# new_args = [
50+
# a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
51+
# for a, dims in zip(args, udimensions)
52+
# ]
5353
max_shape = tuple(args[i].shape[0] for i in indices)
5454
# if is_torchdynamo_exporting():
5555
# for a in args:

0 commit comments

Comments
 (0)