Skip to content

Commit 270e101

Browse files
committed
extend documentation
1 parent 0ef274f commit 270e101

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

_doc/status/exported_program_dynamic.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ with different options. This steps happens before converting into ONNX.
3838
print(name)
3939
print("=" * len(name))
4040
print()
41+
print(f"code: :class:`onnx_diagnostic.torch_export_patches.eval.model_cases.{name}`")
42+
print()
4143
print("forward")
4244
print("+++++++")
4345
print()

_doc/status/exporter_dynamic.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ Exported ONNX with Dynamic Shapes
33
=================================
44

55
The following script shows the exported program for many short cases
6-
and various l-plot-export-with-dynamic-shape to retrieve an ONNX model equivalent
7-
to the original model.
6+
to retrieve an ONNX model equivalent to the original model.
87

98
.. runpython::
109
:showcode:
@@ -40,6 +39,8 @@ to the original model.
4039
print(name)
4140
print("=" * len(name))
4241
print()
42+
print(f"code: :class:`onnx_diagnostic.torch_export_patches.eval.model_cases.{name}`")
43+
print()
4344
print("forward")
4445
print("+++++++")
4546
print()

onnx_diagnostic/torch_export_patches/eval/model_cases.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,34 @@ def forward(self, images, position):
570570
_dynamic = {"images": {0: DYN, 1: DYN}, "position": {0: DYN}}
571571

572572

573+
class ControlFlowWhileDec(torch.nn.Module):
574+
def forward(self, ci, a, b):
575+
def cond_fn(i, x, y):
576+
return i > 0
577+
578+
def body_fn(i, x, y):
579+
return i - 1, x + y, y - x
580+
581+
return torch._higher_order_ops.while_loop(cond_fn, body_fn, [ci, a, b])
582+
583+
_inputs = [(torch.tensor(1), torch.randn(2, 3), torch.randn(2, 3))]
584+
_dynamic = {}, {0: DYN, 1: DYN}, {0: DYN}
585+
586+
587+
class ControlFlowWhileInc(torch.nn.Module):
588+
def forward(self, ci, a, b):
589+
def cond_fn(i, x, y):
590+
return i < x.size(0)
591+
592+
def body_fn(i, x, y):
593+
return i + 1, x + y, y - x
594+
595+
return torch._higher_order_ops.while_loop(cond_fn, body_fn, [ci, a, b])
596+
597+
_inputs = [(torch.tensor(1), torch.randn(2, 3), torch.randn(2, 3))]
598+
_dynamic = {}, {0: DYN, 1: DYN}, {0: DYN}
599+
600+
573601
class SignatureInt1(torch.nn.Module):
574602
def __init__(self, n_dims: int = 3, n_targets: int = 1):
575603
super().__init__()

0 commit comments

Comments
 (0)