Skip to content

Commit f0c543c

Browse files
committed
mypy"
git push "
1 parent 6a0f91f commit f0c543c

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

onnx_diagnostic/helpers/helper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,10 +1204,10 @@ def max_diff(
12041204
else:
12051205
for k, v in d["rep"].items():
12061206
drep[k] += v
1207-
if "dev" in d:
1207+
if "dev" in d and d["dev"] is not None:
12081208
if dd is None:
12091209
dd = d["dev"]
1210-
elif d["dev"] is not None:
1210+
else:
12111211
dd += d["dev"]
12121212

12131213
res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
@@ -1263,7 +1263,7 @@ def max_diff(
12631263
if _index < begin or (end != -1 and _index >= end):
12641264
# out of boundary
12651265
res = dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1266-
if dev:
1266+
if dev is not None:
12671267
res["dev"] = dev
12681268
return res
12691269
if isinstance(expected, (int, float)):
@@ -1279,7 +1279,7 @@ def max_diff(
12791279
n=1,
12801280
dnan=0,
12811281
)
1282-
if dev:
1282+
if dev is not None:
12831283
res["dev"] = dev
12841284
return res
12851285
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
@@ -1361,7 +1361,7 @@ def max_diff(
13611361
res: Dict[str, float] = dict( # type: ignore
13621362
abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
13631363
)
1364-
if dev:
1364+
if dev is not None:
13651365
res["dev"] = dev
13661366
if hist:
13671367
if isinstance(hist, bool):

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,17 @@ def prepare_args_kwargs(
172172
return new_args, new_kwargs
173173

174174

175-
def post_process_run_aligned_obs(obs: Dict[str, Any]) -> Dict[str, Union[str, float, int]]:
175+
def post_process_run_aligned_obs(
176+
obs: Tuple[
177+
Optional[int],
178+
Optional[int],
179+
Optional[str],
180+
Optional[str],
181+
Optional[str],
182+
Optional[str],
183+
Dict[str, Union[int, float]],
184+
],
185+
) -> Dict[str, Union[str, float, int]]:
176186
"""
177187
Flattens an observations produced by function
178188
:func:`onnx_diagnostic.torch_onnx.sbs.run_aligned`.

0 commit comments

Comments
 (0)