diff --git a/_doc/api/export/index.rst b/_doc/api/export/index.rst index 2eef5455..b1a8e5d5 100644 --- a/_doc/api/export/index.rst +++ b/_doc/api/export/index.rst @@ -10,7 +10,7 @@ onnx_diagnostic.export ModelInputs +++++++++++ -.. autoclass:: onnx_diagnostic.dyanmic_shapes.ModelInputs +.. autoclass:: onnx_diagnostic.export.ModelInputs :members: Other functions diff --git a/_doc/conf.py b/_doc/conf.py index f5503109..3203df41 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -121,6 +121,7 @@ ("py:class", "transformers.cache_utils.MambaCache"), ("py:func", "torch.export._draft_export.draft_export"), ("py:func", "torch._export.tools.report_exportability"), + ("py:meth", "transformers.GenerationMixin.generate"), ] nitpick_ignore_regex = [ diff --git a/_doc/examples/plot_export_tiny_llm.py b/_doc/examples/plot_export_tiny_llm.py index 73e86344..8785cf13 100644 --- a/_doc/examples/plot_export_tiny_llm.py +++ b/_doc/examples/plot_export_tiny_llm.py @@ -113,10 +113,6 @@ def _forward_(*args, _f=None, **kwargs): print("result type", string_type(expected_output, with_shape=True)) -ep = torch.export.export( - untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes -) - # %% # It works. # diff --git a/_doc/examples/plot_export_with_dynamic_cache.py b/_doc/examples/plot_export_with_dynamic_cache.py index 2499e985..98da9efe 100644 --- a/_doc/examples/plot_export_with_dynamic_cache.py +++ b/_doc/examples/plot_export_with_dynamic_cache.py @@ -5,7 +5,7 @@ Export with DynamicCache and dynamic shapes =========================================== -Every LLMs implemented in :epkg:`trasnformers` use cache. +Every LLMs implemented in :epkg:`transformers` use cache. One of the most used is :class:`transformers.cache_utils.DynamicCache`. The cache size is dynamic to cope with the growing context. The example shows a tool which determines the dynamic shapes diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index 5fbc2175..bb952ebe 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -45,15 +45,15 @@ def forward(self, x, y): ds = mi.guess_dynamic_shapes() pprint.pprint(ds) - import pprint - import torch - from onnx_diagnostic.export import ModelInputs - **kwargs** .. runpython:: :showcode: + import pprint + import torch + from onnx_diagnostic.export import ModelInputs + class Model(torch.nn.Module): def forward(self, x, y): return x + y @@ -69,15 +69,15 @@ def forward(self, x, y): ds = mi.guess_dynamic_shapes() pprint.pprint(ds) - import pprint - import torch - from onnx_diagnostic.export import ModelInputs - **and and kwargs** .. runpython:: :showcode: + import pprint + import torch + from onnx_diagnostic.export import ModelInputs + class Model(torch.nn.Module): def forward(self, x, y): return x + y