Skip to content

Commit c7096b1

Browse files
committed
narrow input type of type_check
1 parent a483c73 commit c7096b1

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/zarr/core/type_check.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
from typing_extensions import ReadOnly, evaluate_forward_ref
2121

2222

23-
class TypeResolutionError(Exception): ...
24-
25-
2623
@dataclass(frozen=True)
2724
class TypeCheckResult:
2825
"""
@@ -157,7 +154,9 @@ def _resolve_type_impl(
157154
return tp
158155

159156

160-
def check_type(obj: Any, expected_type: Any, path: str = "value") -> TypeCheckResult:
157+
def check_type(
158+
obj: Any, expected_type: type | types.UnionType | ForwardRef | None, path: str = "value"
159+
) -> TypeCheckResult:
161160
"""
162161
Check if `obj` is of type `expected_type`.
163162
"""
@@ -171,7 +170,7 @@ def check_type(obj: Any, expected_type: Any, path: str = "value") -> TypeCheckRe
171170
if expected_type is Any:
172171
return TypeCheckResult(True, [])
173172

174-
if origin is typing.Union or isinstance(expected_type, types.UnionType):
173+
if origin in (typing.Union, types.UnionType):
175174
return check_union(obj, expected_type, path)
176175

177176
if origin is typing.Literal:
@@ -201,11 +200,11 @@ def check_type(obj: Any, expected_type: Any, path: str = "value") -> TypeCheckRe
201200
return check_int(obj, path)
202201

203202
if expected_type in (float, str, bool):
204-
return check_primitive(obj, expected_type, path)
203+
return check_primitive(obj, expected_type, path) # type: ignore[arg-type]
205204

206205
# Fallback
207206
try:
208-
if isinstance(obj, expected_type):
207+
if isinstance(obj, expected_type): # type: ignore[arg-type]
209208
return TypeCheckResult(True, [])
210209
tn = _type_name(expected_type)
211210
return TypeCheckResult(False, [f"{path} expected {tn} but got {type(obj).__name__}"])

0 commit comments

Comments
 (0)