Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Change Logs
0.7.10
++++++

* :pr:`218`: patches sdpa_mask_recent_torch used from _vmap_for_bhqkv

0.7.9
+++++

Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Getting started

git clone https://github.com/sdpython/onnx-diagnostic.git
cd onnx-diagnostic
pip install -e .
pip install -e . -v

or

Expand Down
2 changes: 1 addition & 1 deletion _doc/examples/plot_dump_intermediate_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@
# Let's create the ONNX model.

ep = torch.export.export(model, inputs, dynamic_shapes=ds)
epo = torch.onnx.export(ep, dynamo=True)
epo = torch.onnx.export(ep)
epo.optimize()
epo.save("plot_dump_intermediate_results.onnx")

Expand Down
2 changes: 1 addition & 1 deletion _doc/examples/plot_export_tiny_phi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@

with torch_export_patches(patch_transformers=True):
epo = torch.onnx.export(
ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes
)

# %%
Expand Down
17 changes: 1 addition & 16 deletions _unittests/ut_export/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
ignore_warnings,
requires_onnxscript,
)
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
from onnx_diagnostic.helpers.torch_helper import is_torchdynamo_exporting

try:
Expand Down Expand Up @@ -62,7 +61,7 @@ def test_dummy_loop(self):

@hide_stdout()
@ignore_warnings(UserWarning)
@requires_onnxscript("0.5")
@requires_onnxscript("0.7")
def test_export_loop_onnxscript(self):
class Model(torch.nn.Module):
def forward(self, images, position):
Expand All @@ -75,19 +74,6 @@ def forward(self, images, position):
y = torch.arange(5, dtype=torch.int64) + 1
expected = model(x, y)

name = self.get_dump_file("test_export_loop_onnxscript.onnx")
torch.onnx.export(
model,
(x, y),
name,
dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
dynamo=False,
)
ref = ExtendedReferenceEvaluator(name)
feeds = dict(images=x.numpy(), position=y.numpy())
got = ref.run(None, feeds)[0]
self.assertEqualArray(expected, got)

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(
model,
Expand All @@ -103,7 +89,6 @@ def forward(self, images, position):
(x, y),
name2,
dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
dynamo=True,
fallback=False,
)
import onnxruntime
Expand Down
96 changes: 95 additions & 1 deletion _unittests/ut_helpers/test_cache_helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from typing import Callable
import torch
import transformers
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
Expand All @@ -19,6 +20,13 @@
)
from onnx_diagnostic.torch_export_patches import torch_export_patches

try:
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
patched__vmap_for_bhqkv,
)
except ImportError:
patched__vmap_for_bhqkv = None


class TestCacheHelpers(ExtTestCase):
def test_string_type(self):
Expand Down Expand Up @@ -69,7 +77,7 @@ def test_replace_by(self):
)

DYN = torch.export.Dim.DYNAMIC
nargs, nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
_nargs, _nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
None, args=tuple(), kwargs=kwargs, dynamic_axes=dynamic_shapes
)
self.assertEqual(dynamic_shapes, nds)
Expand Down Expand Up @@ -254,6 +262,92 @@ def test_unflatten_flatten_hybrid_cache(self):
self.string_type(unflat, with_shape=True),
)

@unittest.skipIf(patched__vmap_for_bhqkv is None, "transformers too old")
def test_cache_update_padding_mask_function_vmap(self):
    """Export a model whose mask is built with ``patched__vmap_for_bhqkv``.

    The test combines a causal mask with a padding mask, expands the
    combined mask function with ``patched__vmap_for_bhqkv``, adds the
    resulting mask to the input and checks that ``torch.export.export``
    succeeds with every dimension marked as dynamic.
    """

    def causal_mask_function(
        batch_idx: int, head_idx: int, q_idx: int, kv_idx: int
    ) -> bool:
        # A query position only sees key/value positions at or before it.
        return kv_idx <= q_idx

    def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
        # Closure looking up the padding mask at (batch, kv position).
        def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
            return padding_mask[batch_idx, kv_idx]

        return inner_mask

    def and_masks(*mask_functions: Callable) -> Callable:
        # Every positional argument is one mask function, hence the
        # ``Callable`` annotation (not ``list[Callable]``).
        if not all(callable(arg) for arg in mask_functions):
            raise RuntimeError(
                f"All inputs should be callable mask_functions: {mask_functions}"
            )

        def and_mask(batch_idx, head_idx, q_idx, kv_idx):
            # Start from a scalar ``True`` created from ``q_idx`` and
            # fold every mask into it with a logical and.
            result = q_idx.new_ones((), dtype=torch.bool)
            for mask in mask_functions:
                result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(
                    result.device
                )
            return result

        return and_mask

    class Model(torch.nn.Module):
        def forward(self, x, mask):
            mask_function = and_masks(causal_mask_function, padding_mask_function(mask))
            # Index ranges are derived from the shape of ``x``.
            batch_arange = torch.arange(x.shape[0])
            head_arange = torch.arange(x.shape[3])
            kv_arange = torch.arange(x.shape[1])
            cache_position = torch.arange(x.shape[2])
            f = patched__vmap_for_bhqkv(mask_function)
            causal_mask = f(batch_arange, head_arange, cache_position, kv_arange)
            return x + causal_mask.to(x.dtype)

    inputs = {
        "x": torch.rand((4, 4, 4, 4), dtype=torch.float32),
        "mask": torch.ones((4, 4), dtype=torch.int64),
    }
    model = Model()
    # Eager run first: the model must work before it is exported.
    expected = model(**inputs)
    self.assertNotEmpty(expected)
    DYN = torch.export.Dim.DYNAMIC
    ep = torch.export.export(
        model,
        (),
        kwargs=inputs,
        dynamic_shapes={"x": {0: DYN, 1: DYN, 2: DYN, 3: DYN}, "mask": {0: DYN, 1: DYN}},
    )
    self.assertNotEmpty(ep)

def test_simple_indices(self):
    """Export a model indexing a 2-D tensor with two integer index tensors."""

    class Model(torch.nn.Module):
        def forward(self, x, i, j):
            return x[i, j]

    DYN = torch.export.Dim.DYNAMIC
    index_dims = (4, 4, 4, 4)
    # Same creation order as before: the float tensor, then both indices.
    x = torch.rand((4, 4), dtype=torch.float32)
    i = torch.randint(0, 4, index_dims, dtype=torch.int64)
    j = torch.randint(0, 4, index_dims, dtype=torch.int64)
    args = (x, i, j)

    model = Model()
    eager_result = model(*args)
    self.assertEqual(eager_result.shape, (4, 4, 4, 4))

    # Every dimension of every input is dynamic; both index tensors
    # share the same specification object.
    index_spec = {0: DYN, 1: DYN, 2: DYN, 3: DYN}
    x_spec = {0: DYN, 1: DYN}
    ep = torch.export.export(
        model,
        args,
        dynamic_shapes=(x_spec, index_spec, index_spec),
    )
    self.assertNotEmpty(ep)


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion _unittests/ut_helpers/test_ort_session_tinyllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_check_allruntimes_on_tiny_llm(self):
proto = to_onnx(model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds)
else:
proto = torch.onnx.export(
model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds, dynamo=True
model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds
).model_proto

self.dump_onnx("test_check_allruntimes_on_tiny_llm.onnx", proto)
Expand Down
19 changes: 16 additions & 3 deletions _unittests/ut_torch_export_patches/test_dynamic_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
import unittest
from typing import Any, Dict, List, Tuple
import torch

try:
import transformers.masking_utils as masking_utils
except ImportError:
masking_utils = None
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
ignore_warnings,
Expand All @@ -14,7 +19,9 @@
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
torch_export_patches,
)
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers


class TestOnnxExportErrors(ExtTestCase):
Expand Down Expand Up @@ -305,7 +312,7 @@ def test_phi2_export_module(self):
model,
(),
kwargs=inputs,
dynamic_shapes=dyn_shapes,
dynamic_shapes=use_dyn_not_str(dyn_shapes),
strict=False, # True works but then it fails during the execution
)
# ep = ep.run_decompositions()
Expand All @@ -319,6 +326,7 @@ def test_phi2_export_module(self):

@ignore_warnings(UserWarning)
@requires_torch("2.9")
@hide_stdout()
def test_phi2_export_interpreter(self):
data = get_untrained_model_with_inputs("microsoft/phi-2")
model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"]
Expand All @@ -338,12 +346,17 @@ def test_phi2_export_interpreter(self):
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
)

with torch_export_patches(patch_transformers=True):
with torch_export_patches(patch_transformers=True, verbose=1):
if masking_utils is not None:
self.assertEqual(
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
patch_transformers.patched_sdpa_mask_recent_torch,
)
ep = torch.export.export(
model,
(),
kwargs=inputs,
dynamic_shapes=dyn_shapes,
dynamic_shapes=use_dyn_not_str(dyn_shapes),
strict=False, # True works but then it fails during the execution
)
# ep = ep.run_decompositions()
Expand Down
4 changes: 2 additions & 2 deletions _unittests/ut_torch_export_patches/test_patch_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,8 @@ def loop_body_1(z, iv, x, y):
rewritten_expected2 = RewrittenModel2()(x, y)
self.assertEqualArray(expected, rewritten_expected2)

if not has_torch("2.9"):
raise unittest.SkipTest("skipped export, torch must be >= 2.9")
if not has_torch("2.10"):
raise unittest.SkipTest("skipped export, torch must be >= 2.10")

torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds, strict=False)
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds, strict=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def forward(self, cache):
def test_base_model_output_unflatten_flatten(self):
bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
with torch_export_patches(patch_transformers=True):
flat, _spec = torch.utils._pytree.tree_flatten(bo)
_flat, _spec = torch.utils._pytree.tree_flatten(bo)
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
self.assertIsInstance(unflat, list)
self.assertEqual("#1[T1r3]", self.string_type(unflat))
Expand Down
11 changes: 6 additions & 5 deletions _unittests/ut_torch_export_patches/test_patch_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_vmap(self):
got = patched_vmap(f)(x, y)
self.assertEqualArray(expected, got)

@requires_torch("2.9")
@requires_torch("2.10")
def test_export_vmap(self):
class Model(torch.nn.Module):
def forward(self, x, y):
Expand Down Expand Up @@ -206,10 +206,11 @@ def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callabl

class Model(torch.nn.Module):
def forward(self, batch_arange, head_arange, cache_position, kv_arange):
with TransformGetItemToIndex():
causal_mask2 = _vmap_for_bhqkv2(mask_function)(
batch_arange, head_arange, cache_position, kv_arange
)
# with TransformGetItemToIndex():
# This context was ignored in 2.8 but not any more in 2.9.
causal_mask2 = _vmap_for_bhqkv2(mask_function)(
batch_arange, head_arange, cache_position, kv_arange
)
return causal_mask2

inputs = batch_arange, head_arange, cache_position, kv_arange
Expand Down
33 changes: 29 additions & 4 deletions _unittests/ut_torch_models/test_llm_phi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,51 @@
)
from onnx_diagnostic.torch_models.llms import get_phi2
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestLlmPhi(ExtTestCase):
def test_get_phi2(self):
data = get_phi2(num_hidden_layers=2)
data = get_phi2(num_hidden_layers=2, batch_size=2)
model, inputs = data["model"], data["inputs"]
self.assertIn("DynamicCache", string_type(inputs))
model(**inputs)

@ignore_warnings(UserWarning)
@requires_transformers("4.54")
@requires_torch("2.9.99")
def test_export_phi2_1(self):
def test_export_phi2_1_batch_size_1(self):
# exporting vmap does not work
data = get_phi2(num_hidden_layers=2)
data = get_phi2(num_hidden_layers=2, batch_size=1)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
self.assertEqual(inputs["input_ids"].shape[0], 1)
self.assertEqual(
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
)
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds)
with torch.fx.experimental._config.patch(
backed_size_oblivious=True
), torch_export_patches(patch_transformers=True):
ep = torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
)
assert ep

@ignore_warnings(UserWarning)
@requires_transformers("4.54")
@requires_torch("2.9.99")
def test_export_phi2_1_batch_size_2(self):
    """Export a two-layer phi2 model built with ``batch_size=2``.

    The export runs inside ``torch_export_patches(patch_transformers=True)``
    and the dynamic shapes go through ``use_dyn_not_str`` first.
    """
    # NOTE(review): the original comment here read "exporting vmap does
    # not work" -- presumably the reason for the patches; confirm.
    data = get_phi2(num_hidden_layers=2, batch_size=2)
    model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
    self.assertEqual(inputs["input_ids"].shape[0], 2)
    self.assertEqual(
        {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
    )
    with torch_export_patches(patch_transformers=True):
        ep = torch.export.export(
            model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
        )
    # Unlike a bare ``assert``, this check is not removed under ``python -O``.
    self.assertTrue(ep)


Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_torch_models/test_tiny_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_tiny_llm_export_dynamic(self):
self.assertEqual(
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
)
with torch_export_patches(patch_transformers=True):
with torch_export_patches(patch_transformers=True, verbose=1):
ep = torch.export.export(
model,
(),
Expand Down
Loading
Loading