Skip to content

Commit 654fffe

Browse files
committed
mypy
1 parent 7fad7cf commit 654fffe

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,14 +297,14 @@ def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]):
297297
tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
298298
)
299299
mind = min(d0, d1)
300-
indices = [slice(None) for _ in range(rank)]
300+
indices: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
301301
indices[i] = slice(0, mind)
302302
ind = tuple(indices)
303303
new_tensor[ind] = tensor[ind]
304304
if d1 > mind:
305305
for k in range(d1 - mind):
306-
indices0 = [slice(None) for _ in range(rank)]
307-
indices1 = [slice(None) for _ in range(rank)]
306+
indices0: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
307+
indices1: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
308308
indices1[i] = mind + k
309309
indices0[i] = k % mind
310310
new_tensor[tuple(indices1)] = tensor[tuple(indices0)]

onnx_diagnostic/export/validate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
def compare_modules(
12-
modep: torch.export.ExportedProgram,
12+
modep: torch.nn.Module,
1313
mod: Optional[torch.nn.Module] = None,
1414
args: Optional[Tuple[Any, ...]] = None,
1515
kwargs: Optional[Dict[str, Any]] = None,
@@ -18,7 +18,7 @@ def compare_modules(
1818
verbose: int = 0,
1919
atol: float = 1e-2,
2020
rtol: float = 1e-1,
21-
) -> List[Dict[str, Any]]:
21+
) -> Dict[str, Any]:
2222
"""
2323
Compares two torch modules, usually one coming from an exported program,
2424
the other being the origin model.
@@ -150,8 +150,8 @@ def validate_ep(
150150
values = [_[1] for _ in items]
151151
all_vals = list(itertools.product(*values))
152152
cpl = CoupleInputsDynamicShapes(
153-
args,
154-
kwargs,
153+
args or (),
154+
kwargs or {},
155155
dynamic_shapes,
156156
args_names=(
157157
list(inspect.signature(modep.forward).parameters) if args and kwargs else None

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def get_pretrained_config(
6060
)
6161

6262

63-
def get_model_info(model_id) -> str:
63+
def get_model_info(model_id) -> Any:
6464
"""Returns the model info for a model_id."""
6565
return model_info(model_id)
6666

@@ -220,7 +220,11 @@ def enumerate_model_list(
220220
m.trending_score or "",
221221
m.private or "",
222222
m.gated or "",
223-
("|".join(m.tags)).replace(",", "_").replace(" ", "_"),
223+
(
224+
("|".join(m.tags)).replace(",", "_").replace(" ", "_")
225+
if m.tags
226+
else ""
227+
),
224228
],
225229
)
226230
)

0 commit comments

Comments
 (0)