@@ -39,7 +39,8 @@ def _forward_(*args, _f=None, _context=None, **kwargs):
3939def steal_forward (model : torch .nn .Module , with_shape : bool = True , with_min_max : bool = False ):
4040 """
4141 The necessary modification to steem forward method and prints out inputs
42- and outputs. See example :ref:`l-plot-tiny-llm-export`.
42+ and outputs using :func:`onnx_diagnostic.helpers.string_type`.
43+ See example :ref:`l-plot-tiny-llm-export`.
4344 """
4445 context = dict (
4546 iteration = 0 ,
@@ -58,7 +59,10 @@ def steal_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max:
5859
5960
6061def is_torchdynamo_exporting () -> bool :
61- """Tells if torch is exporting a model."""
62+ """
63+ Tells if :epkg:`torch` is exporting a model.
64+ Relies on ``torch.compiler.is_exporting()``.
65+ """
6266 import torch
6367
6468 if not hasattr (torch .compiler , "is_exporting" ):
@@ -77,7 +81,7 @@ def is_torchdynamo_exporting() -> bool:
7781
7882
7983def to_numpy (tensor : "torch.Tensor" ): # noqa: F821
80- """Converts a torch tensor to numy ."""
84+ """Converts a :class:` torch.Tensor` to :class:`numpy.ndarray` ."""
8185 try :
8286 return tensor .numpy ()
8387 except TypeError :
@@ -309,10 +313,7 @@ def forward(self, input_ids):
309313
310314
311315def to_any (value : Any , to_value : Union [torch .dtype , torch .device ]) -> Any :
312- """
313- Applies torch.to is applicable.
314- Goes recursively.
315- """
316+ """Applies torch.to if applicable. Goes recursively."""
316317 if isinstance (value , (torch .nn .Module , torch .Tensor )):
317318 return value .to (to_value )
318319 if isinstance (value , list ):
@@ -344,9 +345,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
344345
345346
346347def torch_deepcopy (value : Any ) -> Any :
347- """
348- Makes a deepcopy.
349- """
348+ """Makes a deepcopy."""
350349 if value is None :
351350 return None
352351 if isinstance (value , (int , float , str )):
0 commit comments