Skip to content

Commit a5f0c6e

Browse files
authored
Improve documentation (#87)
* fix doc * Improve documentation * ci
1 parent 78ee08c commit a5f0c6e

File tree

3 files changed

+19
-18
lines changed

3 files changed

+19
-18
lines changed

.github/workflows/check-urls.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ jobs:
3030
print_all: false
3131
timeout: 2
3232
retry_count# : 2
33-
exclude_urls: https://github.com/pytorch/pytorch/pull/117009,https://github.com/huggingface/transformers/pull/29285,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475
34-
exclude_patterns: https://dumps.wikimedia.org/,https://github.com/pytorch/pytorch/pull/,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://huggingface.co/,https://huggingface.co/
33+
exclude_urls: https://github.com/pytorch/pytorch/pull/117009,https://github.com/huggingface/transformers/pull/29285,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://github.com/huggingface/transformers/pull/36652
34+
exclude_patterns: https://dumps.wikimedia.org/,https://github.com/pytorch/pytorch/pull/,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://huggingface.co/,https://huggingface.co/,https://github.com/huggingface/transformers/
3535
# force_pass : true
3636

3737
- name: urls-checker-docs

_doc/recipes/plot_dynamic_shapes_nonzero.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,23 @@ def adaptive_enc_mask(self, x_len, chunk_start_idx, left_window=0, right_window=
3535
mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
3636
return mask_left & mask_right
3737

38-
def forward(self, x):
39-
return self.adaptive_enc_mask(x.shape[1], [])
38+
def forward(self, x, y):
39+
return self.adaptive_enc_mask(
40+
x.shape[1], torch.tensor([], dtype=torch.int64), left_window=y.shape[0]
41+
)
4042

4143

4244
model = Model()
43-
x = torch.rand((5, 8))
44-
y = model(x)
45-
print(f"x.shape={x.shape}, y.shape={y.shape}")
45+
x, y = torch.rand((2, 546)), torch.rand((18,))
46+
z = model(x, y)
47+
print(f"y.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
4648

4749
# %%
4850
# Export
4951
# ++++++
5052

5153
DYN = torch.export.Dim.DYNAMIC
52-
ep = torch.export.export(model, (x,), dynamic_shapes=(({0: DYN, 1: DYN}),))
54+
ep = torch.export.export(model, (x, y), dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN}))
5355
print(ep)
5456

5557

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def _forward_(*args, _f=None, _context=None, **kwargs):
3939
def steal_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max: bool = False):
4040
"""
4141
The necessary modification to steem forward method and prints out inputs
42-
and outputs. See example :ref:`l-plot-tiny-llm-export`.
42+
and outputs using :func:`onnx_diagnostic.helpers.string_type`.
43+
See example :ref:`l-plot-tiny-llm-export`.
4344
"""
4445
context = dict(
4546
iteration=0,
@@ -58,7 +59,10 @@ def steal_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max:
5859

5960

6061
def is_torchdynamo_exporting() -> bool:
61-
"""Tells if torch is exporting a model."""
62+
"""
63+
Tells if :epkg:`torch` is exporting a model.
64+
Relies on ``torch.compiler.is_exporting()``.
65+
"""
6266
import torch
6367

6468
if not hasattr(torch.compiler, "is_exporting"):
@@ -77,7 +81,7 @@ def is_torchdynamo_exporting() -> bool:
7781

7882

7983
def to_numpy(tensor: "torch.Tensor"): # noqa: F821
80-
"""Converts a torch tensor to numy."""
84+
"""Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
8185
try:
8286
return tensor.numpy()
8387
except TypeError:
@@ -309,10 +313,7 @@ def forward(self, input_ids):
309313

310314

311315
def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
312-
"""
313-
Applies torch.to is applicable.
314-
Goes recursively.
315-
"""
316+
"""Applies torch.to if applicable. Goes recursively."""
316317
if isinstance(value, (torch.nn.Module, torch.Tensor)):
317318
return value.to(to_value)
318319
if isinstance(value, list):
@@ -344,9 +345,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
344345

345346

346347
def torch_deepcopy(value: Any) -> Any:
347-
"""
348-
Makes a deepcopy.
349-
"""
348+
"""Makes a deepcopy."""
350349
if value is None:
351350
return None
352351
if isinstance(value, (int, float, str)):

0 commit comments

Comments
 (0)