Skip to content

Commit 6fcbaf1

Browse files
committed
lint
1 parent f6a8e0f commit 6fcbaf1

File tree

3 files changed

+40
-15
lines changed

3 files changed

+40
-15
lines changed

.github/workflows/models.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
matrix:
1818
os: [ubuntu-latest]
1919
python: ['3.13']
20-
transformers: ['5.0']
20+
transformers: ['4.57.6']
2121
torch: ['main']
2222
steps:
2323
- uses: actions/checkout@v3

.github/workflows/models448.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
matrix:
2020
os: [ubuntu-latest]
2121
python: ['3.13']
22-
transformers: ['5.0']
22+
transformers: ['4.57.6']
2323
torch: ['main']
2424
steps:
2525
- uses: actions/checkout@v3

onnx_diagnostic/investigate/input_observer.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _infer_dynamic_dimensions(
9696
class InputCandidate:
9797
"""Represents a consistence set of inputs for the exported method."""
9898

99-
def __init__(self, args: list[Any], kwargs: dict[str, Any], cloned: bool):
99+
def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any], cloned: bool):
100100
self.args = args
101101
self.kwargs = kwargs
102102
self.flat_list, self.spec = torch.utils._pytree.tree_flatten((args, kwargs))
@@ -114,7 +114,7 @@ def __init__(self, args: list[Any], kwargs: dict[str, Any], cloned: bool):
114114
)
115115

116116
self.aligned_spec: torch.utils._pytree.PyTreeSpec | None = None
117-
self.aligned_flat_list: list[torch.Tensor | None] = None
117+
self.aligned_flat_list: list[torch.Tensor | None] | None = None
118118

119119
def __str__(self) -> str:
120120
return (
@@ -152,13 +152,17 @@ def position_to_args_kwargs(self) -> list[int | str]:
152152
"""
153153
if self._position_to_args_kwargs is None:
154154
self.build_mappings()
155+
# type checking is missing it
156+
assert self._position_to_args_kwargs is not None
155157
return self._position_to_args_kwargs
156158

157159
@property
158-
def n_tensors_for_args_kwargs(self) -> list[int | str]:
160+
def n_tensors_for_args_kwargs(self) -> dict[int | str, int]:
159161
"""Returns the number of flat tensors in every args or kwargs."""
160162
if self._n_tensors_for_args_kwargs is None:
161163
self.build_mappings()
164+
# type checking is missing it
165+
assert self._n_tensors_for_args_kwargs is not None
162166
return self._n_tensors_for_args_kwargs
163167

164168
def _set_aligned_flat_list(
@@ -255,9 +259,7 @@ def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...]):
255259
self.outputs_specs.append(spec)
256260
self.flat_outputs.append([t.clone().detach() for t in flat_res])
257261

258-
def align_inputs_none_values(
259-
self,
260-
) -> list[list[torch.Tensor]]:
262+
def align_inputs_none_values(self):
261263
"""Once the best candidate is chosen, this method aligns every set of inputs
262264
on the best candidate, it inserts None at the right position when
263265
optional inputs are not specified. We consider a set of inputs is aligned
@@ -283,16 +285,23 @@ def align_inputs_none_values(
283285
candidate.align_with(self._best_candidate, self._captured_inputs)
284286

285287
def infer_dynamic_shapes(
286-
self, set_batch_dimension_for: set[int | str] | None = None
288+
self, set_batch_dimension_for: set[int | str] | None = None, return_flat: bool = False
287289
) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
288290
"""Infers dynamic shapes. based on the collected tensors.
289291
Most of the time, models do support a batch dimension
290292
but this batch dimension has the same value for every input sample.
291293
Instead of running inference on new samples, argument `set_batch_dimension_for`
292294
can be used to tell the first dimension is a dynamic dimension for a particular
293295
set of inputs referenced by their name (str) or their position (int).
296+
297+
`return_flat` tells the function to return a flat tuple instead of
298+
nested structured.
294299
"""
295300
self.align_inputs_none_values()
301+
# type checking
302+
assert self._best_candidate is not None
303+
assert self._best_candidate.flat_list is not None
304+
assert self._best_candidate.aligned_flat_list is not None
296305

297306
def _set_batch_dimension(name_or_position):
298307
if not set_batch_dimension_for:
@@ -309,6 +318,8 @@ def _set_batch_dimension(name_or_position):
309318
return False
310319

311320
def _set_batch_dimension_for_flat_index(index):
321+
# type checking
322+
assert self._best_candidate is not None
312323
return _set_batch_dimension(self._best_candidate.position_to_args_kwargs[index])
313324

314325
if len(self._best_candidate.flat_list) != len(self._best_candidate.aligned_flat_list):
@@ -329,6 +340,7 @@ def _set_batch_dimension_for_flat_index(index):
329340
shape_lists = [
330341
[(None if t is None else t.shape) for t in candidate.aligned_flat_list]
331342
for candidate in self.inputs
343+
if candidate.aligned_flat_list is not None
332344
]
333345
n_tensors = len(shape_lists[0])
334346
dynamic_shapes = [
@@ -340,6 +352,8 @@ def _set_batch_dimension_for_flat_index(index):
340352
]
341353
cst = torch.export.Dim.DYNAMIC
342354
flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
355+
if return_flat:
356+
return tuple(flat_dynamic_shapes)
343357
if len(flat_dynamic_shapes) == len(self._best_candidate.args) + len(
344358
self._best_candidate.kwargs
345359
):
@@ -391,10 +405,9 @@ def infer_arguments(
391405
"""Infers arguments based on the collected tensors."""
392406
# This is already checked by _build_inputs_completed_with_none_values
393407
# but this is not always well captured by tools checking types.
394-
torch._check(
395-
self._best_candidate.args is not None and self._best_candidate.kwargs is not None,
396-
lambda: "No input was captured.",
397-
)
408+
torch._check(self._best_candidate is not None, lambda: "No input was captured.")
409+
# type checking
410+
assert self._best_candidate is not None
398411
candidate = None
399412
if index is None:
400413
for cand in self.inputs:
@@ -412,16 +425,25 @@ def infer_arguments(
412425
candidate = self.inputs[index]
413426

414427
torch._check(candidate is not None, "No input was captured.")
428+
# type checking
429+
assert candidate is not None
430+
assert candidate.aligned_flat_list is not None
415431

416432
aligned_flat_list = candidate.aligned_flat_list
417433
if any(t is None for t in aligned_flat_list):
418-
dynamic_shapes = self.infer_dynamic_shapes()
434+
dynamic_shapes = self.infer_dynamic_shapes(return_flat=True)
435+
# type checking
436+
assert isinstance(dynamic_shapes, tuple)
419437
aligned_flat_list = aligned_flat_list.copy()
420438
for index in range(len(aligned_flat_list)):
421439
if aligned_flat_list[index] is not None:
422440
continue
423441
shape = dynamic_shapes[index]
424-
all_non_empty_tensors = [c.aligned_flat_list[index] for c in self.inputs]
442+
all_non_empty_tensors = [
443+
c.aligned_flat_list[index]
444+
for c in self.inputs
445+
if c.aligned_flat_list is not None
446+
]
425447
all_non_empty_tensors = [t for t in all_non_empty_tensors if t is not None]
426448
if not all_non_empty_tensors:
427449
raise RuntimeError(
@@ -444,6 +466,9 @@ def infer_arguments(
444466
aligned_flat_list[index] = torch.empty(
445467
tuple(new_shape), dtype=tensor.dtype, device=tensor.device
446468
)
469+
# type checking
470+
assert candidate is not None
471+
assert candidate.aligned_spec is not None
447472
args, kwargs = torch.utils._pytree.tree_unflatten(
448473
aligned_flat_list, candidate.aligned_spec
449474
)

0 commit comments

Comments
 (0)