1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.8.0
+++++

* :pr:`280`: fixes patches for sdpa_attention_forward for different versions of transformers
* :pr:`278`: implements ``onnx_generate_with_genai``
* :pr:`277`: changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...)
* :pr:`276`: implements ``onnx_generate``, which implements the ``generate`` method for an onnx model,
13 changes: 10 additions & 3 deletions _doc/recipes/plot_dynamic_shapes_json.py
@@ -22,6 +22,7 @@
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions

bsize, nheads, slen, dim = 2, 1, 30, 96

@@ -41,7 +42,11 @@
# %%
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`
# produces the corresponding dynamic shapes assuming they are all dynamic.
ds = all_dynamic_shapes_from_inputs(inputs)
# ``register_additional_serialization_functions(patch_transformers=True)`` registers
# the functions letting pytorch know how to serialize and deserialize the class DynamicCache.

with register_additional_serialization_functions(patch_transformers=True):
ds = all_dynamic_shapes_from_inputs(inputs)
pprint.pprint(ds)
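
# %%
# A minimal sketch (an editorial addition, not part of this diff, assuming ``torch``
# is already imported above) of what this registration enables: while it is active,
# ``DynamicCache`` is no longer an opaque leaf for pytree, so helpers such as
# ``all_dynamic_shapes_from_inputs`` can traverse its tensors.

import torch.utils._pytree as pytree

cache = make_dynamic_cache(
    [(torch.rand(bsize, nheads, slen, dim), torch.rand(bsize, nheads, slen, dim))]
)
with register_additional_serialization_functions(patch_transformers=True):
    flat, _spec = pytree.tree_flatten(cache)
print("flattened tensors:", [tuple(t.shape) for t in flat])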

# %%
@@ -88,7 +93,8 @@ def flatten_unflatten_like_dynamic_shapes(obj):
return list(subtrees)
raise ValueError(
f"Unable to interpret spec type {spec.type} "
f"(type is {type(spec.type)}, context is {spec.context})."
f"(type is {type(spec.type)}, context is {spec.context}), "
f"obj type is {type(obj)}."
)


@@ -109,7 +115,8 @@ def fix_dynamic_shapes(inputs, dynamic_shapes):
return _align(flat_unflat_inputs, dynamic_shapes)


fixed_ds = fix_dynamic_shapes(inputs, ds2)
with register_additional_serialization_functions(patch_transformers=True):
fixed_ds = fix_dynamic_shapes(inputs, ds2)
pprint.pprint(fixed_ds)

# %%
36 changes: 32 additions & 4 deletions _doc/recipes/plot_export_dim1.py
@@ -13,6 +13,8 @@

import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
@@ -29,21 +31,28 @@ def forward(self, x, y, z):
DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))

ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
print(ep.graph)
print(ep)

# %%
# Same model, a dynamic dimension = 1
# +++++++++++++++++++++++++++++++++++


z = z[:1]

DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))

try:
ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
print(ep.graph)
print(ep)
except Exception as e:
print("ERROR", e)

@@ -54,14 +63,33 @@ def forward(self, x, y, z):
# Same model, a dynamic dimension = 1 and backed_size_oblivious=True
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))

try:
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
print(ep.graph)
print(ep)
except RuntimeError as e:
print("ERROR", e)


# %%
# Final try with patches...
# +++++++++++++++++++++++++

print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))

with torch_export_patches(patch_torch=1):
try:
ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
print(ep)
except RuntimeError as e:
print("ERROR", e)

# %%
# It worked.
# It is difficult to find the right option. It may work on a simple model
# but can be impossible on a bigger model mixing different shapes.

doc.plot_legend("dynamic dimension\nworking with\n0 or 1", "torch.export.export", "green")
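
# %%
# A sketch (an editorial addition, not part of this diff) of the strategy the three
# attempts above illustrate, assuming the same ``model``, ``x``, ``y``, ``z`` and
# ``ds`` defined earlier in this recipe: plain export first, then
# ``backed_size_oblivious=True``, then the patches from ``torch_export_patches``.


def export_with_fallbacks(model, args, dynamic_shapes):
    # 1. plain torch.export.export
    try:
        return torch.export.export(model, args, dynamic_shapes=dynamic_shapes)
    except Exception:
        pass
    # 2. let the exporter avoid specializing dynamic dimensions equal to 0 or 1
    try:
        with torch.fx.experimental._config.patch(backed_size_oblivious=True):
            return torch.export.export(model, args, dynamic_shapes=dynamic_shapes)
    except Exception:
        pass
    # 3. last resort: the patches provided by onnx_diagnostic
    with torch_export_patches(patch_torch=1):
        return torch.export.export(model, args, dynamic_shapes=dynamic_shapes)


print(export_with_fallbacks(model, (x, y, z), (ds, ds, ds)))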
34 changes: 34 additions & 0 deletions _unittests/ut_export/test_shape_helper.py
@@ -183,6 +183,40 @@ def test_make_fake_with_dynamic_dimensions_tensor(self):
self.assertEqual(reshaped.shape[3], 96)
self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])

@requires_transformers("4.55")
@requires_torch("2.9")
def test_make_fake_with_dynamic_dimensions_two_tensors(self):
res = make_fake_with_dynamic_dimensions(
(
torch.rand((2, 32, 30, 96), dtype=torch.float16),
torch.rand((2, 32, 30, 96), dtype=torch.float16),
),
({0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}),
)
reshaped = res[0][0]
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
self.assertEqual(reshaped.shape[1], 32)
self.assertEqual(reshaped.shape[3], 96)
self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])
self.assertEqual(str(res[0][0].shape), str(res[0][1].shape))
sh1 = res[0][0].shape
sh2 = res[0][1].shape
self.assertEqual(sh1[0], sh2[0])
self.assertEqual(sh1[1], sh2[1])
self.assertEqual(sh1[2], sh2[2])
self.assertEqual(sh1[3], sh2[3])

def test_make_fake_with_dynamic_dimensions_attention(self):
query = torch.rand((1, 2, 1, 96), dtype=torch.float32)
key = torch.rand((1, 2, 4, 96), dtype=torch.float32)
value = torch.rand((1, 2, 4, 96), dtype=torch.float32)
ds = ({0: "batch", 2: "seq1"}, {0: "batch", 2: "seq2"}, {0: "batch", 2: "seq2"})
fake_inputs, _ = make_fake_with_dynamic_dimensions((query, key, value), ds)
self.assertEqual(fake_inputs[1].shape, fake_inputs[2].shape)
self.assertEqual(fake_inputs[0].shape[0], fake_inputs[1].shape[0])
self.assertEqual(fake_inputs[0].shape[0], fake_inputs[2].shape[0])

@requires_transformers("4.55")
@requires_torch("2.9")
def test_make_fake_with_dynamic_dimensions_whole(self):
66 changes: 33 additions & 33 deletions _unittests/ut_helpers/test_fake_tensor_helper.py
@@ -3,72 +3,72 @@
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
from onnx_diagnostic.helpers import flatten_object
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.helpers.fake_tensor_helper import make_fake, fake_reshape
from onnx_diagnostic.helpers.fake_tensor_helper import make_fake, FakeTensorContext


class TestMakeTensorHelper(ExtTestCase):

@requires_transformers("4.55")
def test_fake_inputs(self):
inputs, _ = make_fake(
dict(
input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
past_key_values=make_dynamic_cache(
[
(
torch.rand((2, 32, 30, 96), dtype=torch.float16),
torch.rand((2, 32, 30, 96), dtype=torch.float16),
),
(
torch.rand((2, 32, 30, 96), dtype=torch.float16),
torch.rand((2, 32, 30, 96), dtype=torch.float16),
),
]
),
)
)
flat = flatten_object(inputs, drop_keys=True)
for t in flat:
self.assertIsInstance(t, torch.Tensor)
assert all(
isinstance(s, torch.SymInt) for s in t.shape
), f"Wrong type {[type(s) for s in t.shape]} in {t.shape}"

def test_fake_reshape_generic(self):
t = torch.zeros((2, 3, 4, 5), dtype=torch.float32)
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
self.assertEqual(reshaped.shape[1], 3)
self.assertEqual(reshaped.shape[3], 5)

def test_fake_reshape_dim_1(self):
t = torch.zeros((1, 3, 4, 5), dtype=torch.float32)
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
self.assertEqual(reshaped.shape[1], 3)
self.assertEqual(reshaped.shape[3], 5)

def test_fake_reshape_dim_0(self):
t = torch.zeros((0, 3, 4, 5), dtype=torch.float32)
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
self.assertEqual(reshaped.shape[1], 3)
self.assertEqual(reshaped.shape[3], 5)

def test_fake_reshape_different(self):
t = torch.zeros((2, 3, 2, 5), dtype=torch.float32)
reshaped = fake_reshape(t, {0: "batch", 2: "seq_length"})
reshaped = FakeTensorContext().fake_reshape(t, {0: "batch", 2: "seq_length"})
self.assertIsInstance(reshaped.shape[0], torch.SymInt)
self.assertIsInstance(reshaped.shape[2], torch.SymInt)
self.assertEqual(reshaped.shape[1], 3)
self.assertEqual(reshaped.shape[3], 5)
self.assertNotEqual(reshaped.shape[0], reshaped.shape[2])

@requires_transformers("4.55")
def test_fake_inputs(self):
inputs, _ = make_fake(
dict(
input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
past_key_values=make_dynamic_cache(
[
(
torch.rand((2, 32, 30, 96), dtype=torch.float16),
torch.rand((2, 32, 30, 96), dtype=torch.float16),
),
(
torch.rand((2, 32, 30, 96), dtype=torch.float16),
torch.rand((2, 32, 30, 96), dtype=torch.float16),
),
]
),
)
)
flat = flatten_object(inputs, drop_keys=True)
for t in flat:
self.assertIsInstance(t, torch.Tensor)
assert all(
isinstance(s, torch.SymInt) for s in t.shape
), f"Wrong type {[type(s) for s in t.shape]} in {t.shape}"


if __name__ == "__main__":
unittest.main(verbosity=2)
14 changes: 11 additions & 3 deletions _unittests/ut_tasks/test_tasks_text_generation.py
@@ -48,7 +48,7 @@ def test_text_generation_phi_3_mini_128k_instruct(self):

@hide_stdout()
@requires_transformers("4.53")
@requires_torch("2.7.99")
@requires_torch("2.8.99") # check_guards not supported
def test_text_generation_tiny_llm(self):
mid = "arnir0/Tiny-LLM"
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
@@ -58,12 +58,20 @@ def test_text_generation_tiny_llm(self):
expected = model(**torch_deepcopy(inputs))
model(**data["inputs2"])
fake = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes=ds)[0]
with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
with torch_export_patches(patch_transformers=True, verbose=1, patch_torch=False):
ep = torch.export.export(
model, (), kwargs=fake, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
# print(ep)
got = ep.module()(**inputs_copied)
rem = []
for node in ep.graph.nodes:
if "_assert" in str(node.target):
rem.append(node)
for node in rem:
ep.graph.erase_node(node)
ep.graph.lint()
mod = ep.module(check_guards=False)
got = mod(**inputs_copied)
self.assertEqualAny(expected.past_key_values, got.past_key_values)
self.assertEqualArray(expected.logits, got.logits)

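The test above erases the ``_assert`` nodes from the exported graph before calling ``ep.module(check_guards=False)``. A minimal sketch of that pattern as a reusable helper (an editorial addition, not part of this diff; ``check_guards`` requires a recent torch, as the updated ``requires_torch`` marker suggests):

def module_without_guards(ep):
    # Collect, then erase, the assertion nodes inserted at export time.
    to_remove = [node for node in ep.graph.nodes if "_assert" in str(node.target)]
    for node in to_remove:
        ep.graph.erase_node(node)
    ep.graph.lint()
    # Return a module that skips the guard checks on its inputs.
    return ep.module(check_guards=False)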