Skip to content

Commit f45dcfb

Browse files
authored
test: Make skip_requires_pyarrow compatible w/ pytest.param (#3772)
1 parent f1ae31e commit f45dcfb

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

tests/__init__.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,22 @@
55
import sys
66
from importlib.util import find_spec
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any
8+
from typing import TYPE_CHECKING, overload
99

1010
import pytest
1111

1212
from tests import examples_arguments_syntax, examples_methods_syntax
1313

1414
if TYPE_CHECKING:
15-
from collections.abc import Callable, Collection, Iterator, Mapping
15+
from collections.abc import Collection, Iterator, Mapping
1616
from re import Pattern
1717

1818
if sys.version_info >= (3, 11):
1919
from typing import TypeAlias
2020
else:
2121
from typing_extensions import TypeAlias
2222
from _pytest.mark import ParameterSet
23+
from _pytest.mark.structures import Markable
2324

2425
MarksType: TypeAlias = (
2526
"pytest.MarkDecorator | Collection[pytest.MarkDecorator | pytest.Mark]"
@@ -96,9 +97,21 @@ def windows_has_tzdata() -> bool:
9697
"""
9798

9899

100+
@overload
99101
def skip_requires_pyarrow(
100-
fn: Callable[..., Any] | None = None, /, *, requires_tzdata: bool = False
101-
) -> Callable[..., Any]:
102+
fn: None = ..., /, *, requires_tzdata: bool = ...
103+
) -> pytest.MarkDecorator: ...
104+
105+
106+
@overload
107+
def skip_requires_pyarrow(
108+
fn: Markable, /, *, requires_tzdata: bool = ...
109+
) -> Markable: ...
110+
111+
112+
def skip_requires_pyarrow(
113+
fn: Markable | None = None, /, *, requires_tzdata: bool = False
114+
) -> pytest.MarkDecorator | Markable:
102115
"""
103116
``pytest.mark.skipif`` decorator.
104117
@@ -109,7 +122,7 @@ def skip_requires_pyarrow(
109122
https://github.com/vega/altair/issues/3050
110123
111124
.. _pyarrow:
112-
https://pypi.org/project/pyarrow/
125+
https://pypi.org/project/pyarrow/
113126
"""
114127
composed = pytest.mark.skipif(
115128
find_spec("pyarrow") is None, reason="`pyarrow` not installed."
@@ -120,13 +133,7 @@ def skip_requires_pyarrow(
120133
reason="Timezone database is not installed on Windows",
121134
)(composed)
122135

123-
def wrap(test_fn: Callable[..., Any], /) -> Callable[..., Any]:
124-
return composed(test_fn)
125-
126-
if fn is None:
127-
return wrap
128-
else:
129-
return wrap(fn)
136+
return composed if fn is None else composed(fn)
130137

131138

132139
def id_func_str_only(val) -> str:

0 commit comments

Comments
 (0)