Skip to content

Commit 54e8373

Browse files
authored
Fixes text2text generation, add summarization (#104)
* improve task text2text * fix issues * add summarization * more robust * insert clone * fix a few things * assert * fix issues * change * fix * fix
1 parent c3da823 commit 54e8373

File tree

18 files changed

+626
-33
lines changed

18 files changed

+626
-33
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`104`: adds summarization task, adds option ``--rewrite`` to command line ``validate``
8+
* :pr:`101`: first draft to rewrite loops
79
* :pr:`100`: implements a context to automatically rewrite methods or functions with control flows
810
* :pr:`96`: implements ``is_stealing``, ``steal_append`` to complement ``steal_forward``
911
* :pr:`95`: fixes Scan implementation for ``OnnxruntimeEvaluator``

_doc/api/tasks/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Or:
4343
mixture_of_expert
4444
object_detection
4545
sentence_similarity
46+
summarization
4647
text_classification
4748
text_generation
4849
text2text_generation

_doc/api/tasks/summarization.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.tasks.summarization
3+
===================================
4+
5+
.. automodule:: onnx_diagnostic.tasks.summarization
6+
:members:
7+
:no-undoc-members:

_unittests/ut_tasks/test_tasks.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
3+
from onnx_diagnostic.ext_test_case import (
4+
ExtTestCase,
5+
hide_stdout,
6+
has_transformers,
7+
requires_transformers,
8+
)
49
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
510
from onnx_diagnostic.torch_export_patches import torch_export_patches
611
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
@@ -123,7 +128,7 @@ def test_fill_mask(self):
123128
)
124129

125130
@hide_stdout()
126-
def test_feature_extraction(self):
131+
def test_feature_extraction_bart_base(self):
127132
mid = "facebook/bart-base"
128133
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
129134
self.assertEqual(data["task"], "feature-extraction")
@@ -136,6 +141,35 @@ def test_feature_extraction(self):
136141
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
137142
)
138143

144+
@hide_stdout()
145+
def test_feature_extraction_tiny_bart(self):
146+
mid = "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
147+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
148+
self.assertEqual(data["task"], "text2text-generation")
149+
self.assertIn((data["size"], data["n_weights"]), [(3243392, 810848)])
150+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
151+
model(**inputs)
152+
model(**data["inputs2"])
153+
with torch_export_patches(patch_transformers=True, verbose=10):
154+
torch.export.export(
155+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
156+
)
157+
158+
@requires_transformers("4.51.999")
159+
@hide_stdout()
160+
def test_summarization(self):
161+
mid = "facebook/bart-large-cnn"
162+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
163+
self.assertEqual(data["task"], "summarization")
164+
self.assertIn((data["size"], data["n_weights"]), [(1625161728, 406290432)])
165+
model, inputs, _ds = data["model"], data["inputs"], data["dynamic_shapes"]
166+
model(**inputs)
167+
model(**data["inputs2"])
168+
# with torch_export_patches(patch_transformers=True, verbose=10):
169+
# torch.export.export(
170+
# model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
171+
# )
172+
139173
@hide_stdout()
140174
def test_text_classification(self):
141175
mid = "Intel/bert-base-uncased-mrpc"

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,30 @@ def forward(self, x, y):
153153
self.assertEqualAny(expected, ep.module()(x, y))
154154
self.assertEqualAny(expected_, ep.module()(-x, y))
155155

156+
def test_check_syntax_assign_noelse(self):
157+
158+
class Model(torch.nn.Module):
159+
def forward(self, x, y):
160+
161+
def branch_cond_then_1(x):
162+
x = torch.abs(x) + 1
163+
return x
164+
165+
def branch_cond_else_1(x):
166+
return x.clone()
167+
168+
x = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x])
169+
return x + y
170+
171+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
172+
expected, expected_ = Model()(x, y), Model()(-x, y)
173+
DYN = torch.export.Dim.DYNAMIC
174+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
175+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
176+
self.assertIn("cond", [str(getattr(n, "target", "?")) for n in ep.graph.nodes])
177+
self.assertEqualAny(expected, ep.module()(x, y))
178+
self.assertEqualAny(expected_, ep.module()(-x, y))
179+
156180
def test_rewrite_test_in_forward_assign_noelse(self):
157181

158182
class Model(torch.nn.Module):
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import unittest
2+
from onnx_diagnostic.ext_test_case import ExtTestCase
3+
from onnx_diagnostic.torch_models.hghub.hub_data import code_needing_rewriting
4+
5+
6+
class TestHuggingFaceHubModelRewrite(ExtTestCase):
7+
8+
def test_code_needing_rewriting(self):
9+
self.assertEqual(1, len(code_needing_rewriting("BartForConditionalGeneration")))
10+
11+
12+
if __name__ == "__main__":
13+
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
validate_model,
1616
filter_inputs,
1717
run_ort_fusion,
18+
empty,
1819
)
1920
from onnx_diagnostic.tasks import supported_tasks
2021

@@ -32,6 +33,9 @@ def test_get_inputs_for_task(self):
3233
self.assertIn("dynamic_shapes", data)
3334
copy.deepcopy(data["inputs"])
3435

36+
def test_empty(self):
37+
self.assertFalse(empty("float16"))
38+
3539
@hide_stdout()
3640
def test_validate_model(self):
3741
mid = "arnir0/Tiny-LLM"

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,17 @@ def get_parser_validate() -> ArgumentParser:
309309
help="catches exception, report them in the summary",
310310
)
311311
parser.add_argument(
312-
"-p",
313312
"--patch",
314313
default=True,
315314
action=BooleanOptionalAction,
316315
help="applies patches before exporting",
317316
)
317+
parser.add_argument(
318+
"--rewrite",
319+
default=True,
320+
action=BooleanOptionalAction,
321+
help="applies rewrite before exporting",
322+
)
318323
parser.add_argument(
319324
"--stop-if-static",
320325
default=0,
@@ -411,6 +416,7 @@ def _cmd_validate(argv: List[Any]):
411416
dtype=args.dtype,
412417
device=args.device,
413418
patch=args.patch,
419+
rewrite=args.rewrite,
414420
stop_if_static=args.stop_if_static,
415421
optimization=args.opt,
416422
exporter=args.export,

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -689,9 +689,22 @@ def forward(self, input_ids):
689689
raise NotImplementedError(f"cls_name={cls_name}")
690690

691691

692-
def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
692+
def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
693693
"""Applies torch.to if applicable. Goes recursively."""
694-
if isinstance(value, (torch.nn.Module, torch.Tensor)):
694+
if isinstance(value, (torch.nn.Module, torch.Tensor)) and value.__class__.__name__ not in {
695+
"DynamicCache",
696+
"EncoderDecoderCache",
697+
}:
698+
if (
699+
(
700+
isinstance(to_value, torch.dtype)
701+
or to_value in {"float16", "bfloat16", "float32", "float64"}
702+
)
703+
and hasattr(value, "dtype")
704+
and value.dtype in {torch.int32, torch.int64, torch.int8, torch.int16}
705+
):
706+
# int vector should not be changed.
707+
return value
695708
return value.to(to_value)
696709
if isinstance(value, list):
697710
return [to_any(t, to_value) for t in value]
@@ -701,8 +714,6 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
701714
return {to_any(t, to_value) for t in value}
702715
if isinstance(value, dict):
703716
return {k: to_any(t, to_value) for k, t in value.items()}
704-
if hasattr(value, "to"):
705-
return value.to(to_value)
706717
if value.__class__.__name__ == "DynamicCache":
707718
return make_dynamic_cache(
708719
list(
@@ -712,11 +723,23 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
712723
)
713724
)
714725
)
726+
if value.__class__.__name__ == "EncoderDecoderCache":
727+
return make_encoder_decoder_cache(
728+
to_any(value.self_attention_cache, to_value),
729+
to_any(value.cross_attention_cache, to_value),
730+
)
715731
if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
716732
args, spec = torch.utils._pytree.tree_flatten(value)
717733
new_args = to_any(args, to_value)
718734
return torch.utils._pytree.tree_unflatten(new_args, spec)
719735

736+
if hasattr(value, "to"):
737+
return value.to(to_value)
738+
739+
assert "Cache" not in value.__class__.__name__, (
740+
f"Class {value.__class__.__name__!r} should be registered "
741+
f"to be able to change the type in every tensor it contains."
742+
)
720743
assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
721744
return value
722745

onnx_diagnostic/tasks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
mixture_of_expert,
99
object_detection,
1010
sentence_similarity,
11+
summarization,
1112
text_classification,
1213
text_generation,
1314
text2text_generation,
@@ -23,6 +24,7 @@
2324
mixture_of_expert,
2425
object_detection,
2526
sentence_similarity,
27+
summarization,
2628
text_classification,
2729
text_generation,
2830
text2text_generation,

0 commit comments

Comments
 (0)