Skip to content

Commit b893fc9

Browse files
committed
try
1 parent 63576f3 commit b893fc9

File tree

2 files changed

+30
-4
lines changed

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,18 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange):
204204
ep = torch.export.export(Model(), inputs, dynamic_shapes=ds)
205205
self.assertEqualArray(causal_mask, ep.module()(*inputs))
206206

207+
@requires_torch("2.7")
208+
def test_export_unsqueeze(self):
209+
class Model(torch.nn.Module):
210+
def forward(self, x):
211+
return x.unsqueeze(0).unsqueeze(2).unsqueeze(3)
212+
213+
x = torch.tensor([7.0, 8.0])
214+
Model()(x)
215+
DYN = torch.export.Dim.DYNAMIC
216+
ep = torch.export.export(Model(), (x,), dynamic_shapes=({0: DYN},))
217+
self.assertEqualArray(Model()(x), ep.module()(x))
218+
207219

208220
if __name__ == "__main__":
209221
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,30 @@ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) ->
3131
def vector_mask_function(
3232
*args, mask_function=mask_function, dimensions=dimensions, indices=indices
3333
):
34-
assert len(args) == len(dimensions) == len(udimensions) + 1, (
34+
assert len(args) == len(dimensions) == len(udimensions), (
3535
f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
36-
f"and udimensions={udimensions}"
36+
f"and udimensions={udimensions}."
3737
)
38+
assert len(indices) == len(args), (
39+
f"Mismatch between args={string_type(args)} and indices={indices}, "
40+
f"they should have the same length."
41+
)
42+
for a in args:
43+
assert (
44+
a.ndim == 1
45+
), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
46+
torch._check(a.shape[0] > 0)
47+
3848
# new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
3949
new_args = [
40-
a.reshape((-1,)).unsqueeze(shape[0]).unsqueeze(shape[1]).unsqueeze(shape[2])
41-
for a, shape in zip(args, udimensions)
50+
a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
51+
for a, dims in zip(args, udimensions)
4252
]
4353
max_shape = tuple(args[i].shape[0] for i in indices)
54+
if is_torchdynamo_exporting():
55+
for a in args:
56+
# The exporter should export with a dimension > 1 to make sure it is dynamic.
57+
torch._check(a.shape[0] > 1)
4458
expanded_args = [a.expand(max_shape) for a in new_args]
4559
return mask_function(*expanded_args)
4660

0 commit comments

Comments (0)