|
3 | 3 | import asyncio |
4 | 4 | import pickle |
5 | 5 | import typing |
| 6 | +from contextlib import nullcontext |
6 | 7 |
|
7 | 8 | import numpy as np |
8 | 9 | import pandas as pd |
@@ -536,42 +537,67 @@ class OnlyOnesSchema(DataFrameModel): |
536 | 537 | a: Series[int] = Field(eq=1) |
537 | 538 |
|
538 | 539 |
|
539 | | -def test_check_types_arguments() -> None: |
| 540 | +@pytest.mark.parametrize( |
| 541 | + ("check_types_args", "df_return", "expected"), |
| 542 | + [ |
| 543 | + pytest.param( |
| 544 | + dict(), |
| 545 | + pd.DataFrame({"a": [0, 0]}), |
| 546 | + nullcontext(), |
| 547 | + id="validate entire df (2 rows)", |
| 548 | + ), |
| 549 | + pytest.param( |
| 550 | + dict(head=1), |
| 551 | + pd.DataFrame({"a": [0, 1]}), |
| 552 | + nullcontext(), |
| 553 | + id="validate header (row 1) - while row 2 is bad", |
| 554 | + ), |
| 555 | + pytest.param( |
| 556 | + dict(tail=1), |
| 557 | + pd.DataFrame({"a": [1, 0]}), |
| 558 | + nullcontext(), |
| 559 | + id="validate tail, (row 2) - while row 1 is bad", |
| 560 | + ), |
| 561 | + pytest.param( |
| 562 | + dict(lazy=True), |
| 563 | + pd.DataFrame({"a": [0, 0]}), |
| 564 | + nullcontext(), |
| 565 | + id="validate entire df (2 rows) - lazy mode ", |
| 566 | + ), |
| 567 | + pytest.param( |
| 568 | + dict(lazy=True), |
| 569 | + pd.DataFrame({"a": [1, 1]}), |
| 570 | + pytest.raises( |
| 571 | + errors.SchemaErrors, |
| 572 | + match=r"DATA", # error msg is specific for lazy- |
| 573 | + ), |
| 574 | + id="1's not allowed in schema - lazy mode", |
| 575 | + ), |
| 576 | + pytest.param( |
| 577 | + dict(), |
| 578 | + pd.DataFrame({"a": [1, 1]}), |
| 579 | + pytest.raises( |
| 580 | + errors.SchemaError, |
| 581 | + match=r"failed element-wise validator", # error msg is specific for regular-mode |
| 582 | + ), |
| 583 | + id="1's not allowed in schema - regular mode", |
| 584 | + ), |
| 585 | + ], |
| 586 | +) |
| 587 | +def test_check_types_arguments( |
| 588 | + check_types_args: dict, df_return: pd.DataFrame, expected |
| 589 | +) -> None: |
540 | 590 | """Test that check_types forwards key-words arguments to validate.""" |
541 | 591 | df = pd.DataFrame({"a": [0, 0]}) |
542 | 592 |
|
543 | | - @check_types() |
544 | | - def transform_empty_parenthesis( |
| 593 | + @check_types(**check_types_args) |
| 594 | + def transform_with_checks( |
545 | 595 | df: DataFrame[OnlyZeroesSchema], |
546 | 596 | ) -> DataFrame[OnlyZeroesSchema]: # pylint: disable=unused-argument |
547 | | - return df |
548 | | - |
549 | | - transform_empty_parenthesis(df) # type: ignore |
550 | | - |
551 | | - @check_types(head=1) |
552 | | - def transform_head( |
553 | | - df: DataFrame[OnlyZeroesSchema], # pylint: disable=unused-argument |
554 | | - ) -> DataFrame[OnlyZeroesSchema]: |
555 | | - return pd.DataFrame({"a": [0, 0]}) # type: ignore |
556 | | - |
557 | | - transform_head(df) # type: ignore |
558 | | - |
559 | | - @check_types(tail=1) |
560 | | - def transform_tail( |
561 | | - df: DataFrame[OnlyZeroesSchema], # pylint: disable=unused-argument |
562 | | - ) -> DataFrame[OnlyZeroesSchema]: |
563 | | - return pd.DataFrame({"a": [1, 0]}) # type: ignore |
564 | | - |
565 | | - transform_tail(df) # type: ignore |
| 597 | + return df_return |
566 | 598 |
|
567 | | - @check_types(lazy=True) |
568 | | - def transform_lazy( |
569 | | - df: DataFrame[OnlyZeroesSchema], # pylint: disable=unused-argument |
570 | | - ) -> DataFrame[OnlyZeroesSchema]: |
571 | | - return pd.DataFrame({"a": [1, 1]}) # type: ignore |
572 | | - |
573 | | - with pytest.raises(errors.SchemaErrors, match=r"DATA"): |
574 | | - transform_lazy(df) # type: ignore |
| 599 | + with expected: |
| 600 | + transform_with_checks(df) # type: ignore |
575 | 601 |
|
576 | 602 |
|
577 | 603 | def test_check_types_unchanged() -> None: |
@@ -754,6 +780,39 @@ def transform( |
754 | 780 | assert transform(None) is None |
755 | 781 |
|
756 | 782 |
|
| 783 | +@pytest.mark.parametrize( |
| 784 | + "callable_annotation", |
| 785 | + [ |
| 786 | + pytest.param(typing.Callable[[None], None], id="no args, no return"), |
| 787 | + pytest.param(typing.Callable[[None], int], id="no args, returns int"), |
| 788 | + pytest.param( |
| 789 | + typing.Callable[..., int], id="no info on args, returns int" |
| 790 | + ), |
| 791 | + pytest.param( |
| 792 | + typing.Callable[..., list[int]], |
| 793 | + id="no info on args, returns list of int", |
| 794 | + ), |
| 795 | + pytest.param( |
| 796 | + typing.Callable[[typing.Any], int], |
| 797 | + id="includes info on callable args", |
| 798 | + ), |
| 799 | + ], |
| 800 | +) |
| 801 | +def test_check_types_callables(callable_annotation: typing.Callable) -> None: |
| 802 | + """ |
| 803 | + Ensures `check_types` validates a dataframe, while passing in an additional callable argument |
| 804 | + """ |
| 805 | + |
| 806 | + class MySchema1(DataFrameModel): |
| 807 | + a: int |
| 808 | + |
| 809 | + @check_types |
| 810 | + def some_transformation(df: MySchema1, f: callable_annotation): # type: ignore[valid-type] |
| 811 | + pass |
| 812 | + |
| 813 | + _ = some_transformation(pd.DataFrame({"a": [1, 2]}), lambda x: 1) |
| 814 | + |
| 815 | + |
757 | 816 | def test_check_types_coerce() -> None: |
758 | 817 | """Test that check_types return the result of validate.""" |
759 | 818 |
|
|
0 commit comments