4 changes: 2 additions & 2 deletions .github/workflows/check-urls.yml
@@ -30,8 +30,8 @@ jobs:
           print_all: false
           timeout: 2
           retry_count# : 2
-          exclude_urls: https://github.com/pytorch/pytorch/pull/117009,https://github.com/huggingface/transformers/pull/29285,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475
-          exclude_patterns: https://dumps.wikimedia.org/,https://github.com/pytorch/pytorch/pull/,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://huggingface.co/,https://huggingface.co/
+          exclude_urls: https://github.com/pytorch/pytorch/pull/117009,https://github.com/huggingface/transformers/pull/29285,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://github.com/huggingface/transformers/pull/36652
+          exclude_patterns: https://dumps.wikimedia.org/,https://github.com/pytorch/pytorch/pull/,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://huggingface.co/,https://huggingface.co/,https://github.com/huggingface/transformers/
           # force_pass : true

       - name: urls-checker-docs
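The two options serve different purposes: exclude_urls is read as a list of exact URLs to skip, while exclude_patterns skips any URL that contains one of the given substrings. That reading is an assumption about the URL checker used in this workflow, not something stated in the diff; a rough Python sketch of the intended filtering logic, for illustration only:

# Illustrative only: assumed semantics of exclude_urls (exact match) vs.
# exclude_patterns (substring match); not the URL checker's actual source code.
def is_excluded(url, exclude_urls, exclude_patterns):
    if url in exclude_urls:                          # exact match
        return True
    return any(p in url for p in exclude_patterns)   # substring match

print(is_excluded(
    "https://github.com/huggingface/transformers/pull/36652",
    ["https://github.com/huggingface/transformers/pull/36652"],
    ["https://github.com/huggingface/transformers/"],
))  # True: excluded by either rule after this change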
14 changes: 8 additions & 6 deletions _doc/recipes/plot_dynamic_shapes_nonzero.py
@@ -35,21 +35,23 @@ def adaptive_enc_mask(self, x_len, chunk_start_idx, left_window=0, right_window=
         mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
         return mask_left & mask_right

-    def forward(self, x):
-        return self.adaptive_enc_mask(x.shape[1], [])
+    def forward(self, x, y):
+        return self.adaptive_enc_mask(
+            x.shape[1], torch.tensor([], dtype=torch.int64), left_window=y.shape[0]
+        )


 model = Model()
-x = torch.rand((5, 8))
-y = model(x)
-print(f"x.shape={x.shape}, y.shape={y.shape}")
+x, y = torch.rand((2, 546)), torch.rand((18,))
+z = model(x, y)
+print(f"y.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")

 # %%
 # Export
 # ++++++

 DYN = torch.export.Dim.DYNAMIC
-ep = torch.export.export(model, (x,), dynamic_shapes=(({0: DYN, 1: DYN}),))
+ep = torch.export.export(model, (x, y), dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN}))
 print(ep)

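Because both inputs are exported with Dim.DYNAMIC, the resulting program should accept other sizes. A small follow-up check, sketched as a continuation of the script above (it assumes model and ep are still in scope):

# Re-run the ExportedProgram with differently sized inputs to confirm the
# dimensions stayed dynamic instead of being specialized to 546 and 18.
x2, y2 = torch.rand((3, 17)), torch.rand((5,))
expected = model(x2, y2)
got = ep.module()(x2, y2)
print(got.shape, torch.equal(expected, got))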
19 changes: 9 additions & 10 deletions onnx_diagnostic/helpers/torch_test_helper.py
@@ -39,7 +39,8 @@ def _forward_(*args, _f=None, _context=None, **kwargs):
 def steal_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max: bool = False):
     """
     The necessary modification to steem forward method and prints out inputs
-    and outputs. See example :ref:`l-plot-tiny-llm-export`.
+    and outputs using :func:`onnx_diagnostic.helpers.string_type`.
+    See example :ref:`l-plot-tiny-llm-export`.
     """
     context = dict(
         iteration=0,
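A hypothetical usage sketch of what the updated docstring describes, assuming steal_forward is used as a context manager around an inference call; the example :ref:`l-plot-tiny-llm-export` mentioned in the docstring is the authoritative reference:

# Assumed usage: while the context is active, the inputs and outputs of the
# patched forward are printed with string_type (shapes, plus min/max here).
import torch
from onnx_diagnostic.helpers.torch_test_helper import steal_forward

model = torch.nn.Linear(4, 2)
with steal_forward(model, with_shape=True, with_min_max=True):
    model(torch.randn(3, 4))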
@@ -58,7 +59,10 @@ def steal_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max:


 def is_torchdynamo_exporting() -> bool:
-    """Tells if torch is exporting a model."""
+    """
+    Tells if :epkg:`torch` is exporting a model.
+    Relies on ``torch.compiler.is_exporting()``.
+    """
     import torch

     if not hasattr(torch.compiler, "is_exporting"):
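The helper is typically queried inside a model to choose an export-friendly branch. A minimal sketch (the module below is made up for illustration):

import torch
from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting

class WithExportBranch(torch.nn.Module):
    def forward(self, x):
        if is_torchdynamo_exporting():
            # keep the traced graph free of data-dependent control flow
            return torch.relu(x)
        # eager-only path, allowed to look at tensor values
        return torch.relu(x) if bool(x.max() >= 0) else x

ep = torch.export.export(WithExportBranch(), (torch.randn(3),))
print(is_torchdynamo_exporting())  # False outside of export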
@@ -77,7 +81,7 @@ def is_torchdynamo_exporting() -> bool:


 def to_numpy(tensor: "torch.Tensor"): # noqa: F821
-    """Converts a torch tensor to numy."""
+    """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
     try:
         return tensor.numpy()
     except TypeError:
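The try/except TypeError points at dtypes numpy cannot represent directly, bfloat16 being the usual case; the fallback itself is outside this hunk, so the sketch below only illustrates the failure mode and a common workaround:

import torch

t = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
try:
    t.numpy()  # raises TypeError: numpy has no bfloat16
except TypeError:
    print(t.to(torch.float32).numpy())  # upcast first, then convert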
@@ -309,10 +313,7 @@ def forward(self, input_ids):


 def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
-    """
-    Applies torch.to is applicable.
-    Goes recursively.
-    """
+    """Applies torch.to if applicable. Goes recursively."""
     if isinstance(value, (torch.nn.Module, torch.Tensor)):
         return value.to(to_value)
     if isinstance(value, list):
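A usage sketch for to_any, based on the branches visible in this hunk (modules and tensors are moved with .to, lists are walked recursively); how other types are treated is an assumption:

import torch
from onnx_diagnostic.helpers.torch_test_helper import to_any

batch = [torch.zeros(2, 2), [torch.ones(3)], "label"]
half = to_any(batch, torch.float16)
# expected: both tensors converted to float16; the string is assumed to
# pass through unchanged
print(half[0].dtype, half[1][0].dtype, half[2])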
@@ -344,9 +345,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:


 def torch_deepcopy(value: Any) -> Any:
-    """
-    Makes a deepcopy.
-    """
+    """Makes a deepcopy."""
     if value is None:
         return None
     if isinstance(value, (int, float, str)):
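A sketch of what "makes a deepcopy" is expected to mean in practice; that tensors are cloned rather than aliased is an assumption, since tensor handling sits outside the lines shown here:

import torch
from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy

t = torch.arange(4)
snapshot = torch_deepcopy(t)
t += 1
print(snapshot)  # expected to remain tensor([0, 1, 2, 3]) if tensors are cloned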