Commit b0cb60f
feat(pyspark): add Pydantic integration tests with PySpark

- Implement tests for the integration between PySpark and Pydantic.
- Create sample schema models and validate data using Pydantic.

Signed-off-by: Ezequiel Leonardo Castaño <[email protected]>

1 parent 37dae8b commit b0cb60f

File tree

4 files changed: +247 -15 lines changed
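Before the per-file diffs, a minimal sketch of the usage this commit enables, mirroring the new tests (a hedged illustration assuming pyspark and pandera's pyspark support are installed; not taken verbatim from pandera's docs):

# A pandera DataFrameModel used as a pydantic field type: instantiating
# the pydantic model triggers pandera validation of the dataframe.
import pyspark.sql.types as T
from pydantic import BaseModel
from pyspark.sql import SparkSession

import pandera.pyspark as pa
from pandera.pyspark import DataFrameModel
from pandera.typing.pyspark_sql import DataFrame


class SampleSchema(DataFrameModel):
    product: T.StringType() = pa.Field()
    price: T.IntegerType() = pa.Field()


class Container(BaseModel):
    # Validated against SampleSchema when Container is instantiated.
    data: DataFrame[SampleSchema]


spark = SparkSession.builder.getOrCreate()
schema = T.StructType(
    [
        T.StructField("product", T.StringType(), False),
        T.StructField("price", T.IntegerType(), False),
    ]
)
df = spark.createDataFrame([("Apples", 5)], schema)
container = Container(data=df)  # passes validation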

pandera/api/pyspark/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -2,3 +2,4 @@
 
 from pandera.api.pyspark.components import Column
 from pandera.api.pyspark.container import DataFrameSchema
+from pandera.api.pyspark.model import DataFrameModel
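This re-export is what the pandera/typing/pyspark_sql.py diff below relies on, letting it import the model class from the API subpackage instead of from pandera.typing.pandas:

from pandera.api.pyspark import DataFrameModel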

pandera/typing/pyspark.py

Lines changed: 80 additions & 2 deletions

@@ -1,14 +1,22 @@
 """Pandera type annotations for Pyspark Pandas."""
 
-from typing import TYPE_CHECKING, Generic, TypeVar
+import functools
+import json
+from typing import TYPE_CHECKING, Generic, TypeVar, Any, get_args
 
+from pydantic import GetCoreSchemaHandler
+from pydantic_core import core_schema
+
+from pandera.engines import PYDANTIC_V2
+from pandera.errors import SchemaInitError
 from pandera.typing.common import (
     DataFrameBase,
     GenericDtype,
     IndexBase,
     SeriesBase,
+    _GenericAlias,
 )
-from pandera.typing.pandas import DataFrameModel, _GenericAlias
+from pandera.typing.pandas import DataFrameModel
 
 try:
     import pyspark.pandas as ps

@@ -39,6 +47,76 @@ def __class_getitem__(cls, item):
             """Define this to override's pyspark.pandas generic type."""
             return _GenericAlias(cls, item)
 
+        @classmethod
+        def pydantic_validate(cls, obj: Any, schema_model: T) -> ps.DataFrame:
+            """
+            Verify that the input can be converted into a pandas dataframe that
+            meets all schema requirements.
+
+            This is for pydantic >= v2
+            """
+            try:
+                schema = schema_model.to_schema()  # type: ignore[attr-defined]
+            except SchemaInitError as exc:
+                error_message = (
+                    f"Cannot use {cls} as a pydantic type as its "
+                    "DataFrameModel cannot be converted to a DataFrameSchema.\n"
+                    f"Please revisit the model to address the following errors:"
+                    f"\n{exc}"
+                )
+                raise ValueError(error_message) from exc
+
+            validated_data = schema.validate(obj)
+
+            if validated_data.pandera.errors:
+                errors = json.dumps(
+                    dict(validated_data.pandera.errors), indent=4
+                )
+                raise ValueError(errors)
+
+            return validated_data
+
+        if PYDANTIC_V2:
+
+            @classmethod
+            def __get_pydantic_core_schema__(
+                cls, _source_type: Any, _handler: GetCoreSchemaHandler
+            ) -> core_schema.CoreSchema:
+                schema_model = get_args(_source_type)[0]
+                return core_schema.no_info_plain_validator_function(
+                    functools.partial(
+                        cls.pydantic_validate,
+                        schema_model=schema_model,
+                    ),
+                )
+
+        else:
+
+            @classmethod
+            def __get_validators__(cls):
+                yield cls._pydantic_validate
+
+            @classmethod
+            def _get_schema_model(cls, field):
+                if not field.sub_fields:
+                    raise TypeError(
+                        "Expected a typed pandera.typing.DataFrame,"
+                        " e.g. DataFrame[Schema]"
+                    )
+                schema_model = field.sub_fields[0].type_
+                return schema_model
+
+            @classmethod
+            def _pydantic_validate(cls, obj: Any, field) -> ps.DataFrame:
+                """
+                Verify that the input can be converted into a pandas dataframe that
+                meets all schema requirements.
+
+                This is for pydantic < v2
+                """
+                schema_model = cls._get_schema_model(field)
+                return cls.pydantic_validate(obj, schema_model)
+
     # pylint:disable=too-few-public-methods,arguments-renamed
     class Series(SeriesBase, ps.Series, Generic[GenericDtype]):  # type: ignore [misc]  # noqa
         """Representation of pandas.Series, only used for type annotation.

pandera/typing/pyspark_sql.py

Lines changed: 85 additions & 13 deletions

@@ -1,9 +1,16 @@
-"""Pandera type annotations for Pyspark."""
+"""Pandera type annotations for Pyspark SQL."""
 
-from typing import TypeVar, Union
+import functools
+import json
+from typing import Union, TypeVar, Any, get_args, Generic
 
-from pandera.typing.common import DataFrameBase
-from pandera.typing.pandas import DataFrameModel, _GenericAlias
+from pydantic import GetCoreSchemaHandler
+from pydantic_core import core_schema
+
+from pandera.engines import pyspark_engine, PYDANTIC_V2
+from pandera.errors import SchemaInitError
+from pandera.typing.common import DataFrameBase, _GenericAlias
+from pandera.api.pyspark import DataFrameModel
 
 try:
     import pyspark.sql as ps

@@ -12,9 +19,9 @@
 except ImportError:  # pragma: no cover
     PYSPARK_SQL_INSTALLED = False
 
-if PYSPARK_SQL_INSTALLED:
-    from pandera.engines import pyspark_engine
+T = TypeVar("T", bound=DataFrameModel)
 
+if PYSPARK_SQL_INSTALLED:
     PysparkString = pyspark_engine.String
     PysparkInt = pyspark_engine.Int
     PysparkLongInt = pyspark_engine.BigInt

@@ -43,13 +50,6 @@
             PysparkBinary,  # type: ignore
         ],
     )
-    from typing import TYPE_CHECKING, Generic
-
-    # pylint:disable=invalid-name
-    if TYPE_CHECKING:
-        T = TypeVar("T")  # pragma: no cover
-    else:
-        T = DataFrameModel
 
 if PYSPARK_SQL_INSTALLED:
     # pylint: disable=too-few-public-methods,arguments-renamed

@@ -64,3 +64,75 @@ class DataFrame(DataFrameBase, ps.DataFrame, Generic[T]):
         def __class_getitem__(cls, item):
             """Define this to override's pyspark.pandas generic type."""
             return _GenericAlias(cls, item)  # pragma: no cover
+
+        @classmethod
+        def pydantic_validate(
+            cls, obj: ps.DataFrame, schema_model: T
+        ) -> ps.DataFrame:
+            """
+            Verify that the input can be converted into a pyspark sql dataframe that
+            meets all schema requirements.
+
+            This is for pydantic V1 and V2.
+            """
+            try:
+                schema = schema_model.to_schema()
+            except SchemaInitError as exc:
+                error_message = (
+                    f"Cannot use {cls} as a pydantic type as its "
+                    "DataFrameModel cannot be converted to a DataFrameSchema.\n"
+                    f"Please revisit the model to address the following errors:"
+                    f"\n{exc}"
+                )
+                raise ValueError(error_message) from exc
+
+            validated_data = schema.validate(obj)
+
+            if validated_data.pandera.errors:
+                errors = json.dumps(
+                    dict(validated_data.pandera.errors), indent=4
+                )
+                raise ValueError(errors)
+
+            return validated_data
+
+        if PYDANTIC_V2:
+
+            @classmethod
+            def __get_pydantic_core_schema__(
+                cls, _source_type: Any, _handler: GetCoreSchemaHandler
+            ) -> core_schema.CoreSchema:
+                schema_model = get_args(_source_type)[0]
+                return core_schema.no_info_plain_validator_function(
+                    functools.partial(
+                        cls.pydantic_validate,
+                        schema_model=schema_model,
+                    ),
+                )
+
+        else:
+
+            @classmethod
+            def __get_validators__(cls):
+                yield cls._pydantic_validate
+
+            @classmethod
+            def _get_schema_model(cls, field):
+                if not field.sub_fields:
+                    raise TypeError(
+                        "Expected a typed pandera.typing.DataFrame,"
+                        " e.g. DataFrame[Schema]"
+                    )
+                schema_model = field.sub_fields[0].type_
+                return schema_model
+
+            @classmethod
+            def _pydantic_validate(cls, obj: Any, field) -> ps.DataFrame:
+                """
+                Verify that the input can be converted into a pyspark sql dataframe that
+                meets all schema requirements.
+
+                This is for pydantic v1
+                """
+                schema_model = cls._get_schema_model(field)
+                return cls.pydantic_validate(obj, schema_model)
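Note that pandera's pyspark backend collects failures on the dataframe's `pandera` accessor rather than raising, which is why `pydantic_validate` inspects `validated_data.pandera.errors` and re-raises them as a single ValueError. A hedged sketch of exercising that path directly (the `Product` model is an invented example; assumes a running SparkSession):

# Calling the validator directly on bad data raises a ValueError whose
# message is the JSON dump of the collected schema errors, which pydantic
# then surfaces as a ValidationError when used as a field type.
import pyspark.sql.types as T
from pyspark.sql import SparkSession

import pandera.pyspark as pa
from pandera.pyspark import DataFrameModel
from pandera.typing.pyspark_sql import DataFrame


class Product(DataFrameModel):
    product: T.StringType() = pa.Field()
    price: T.IntegerType() = pa.Field()


spark = SparkSession.builder.getOrCreate()
bad = spark.createDataFrame([(1, "Apples")], ["product", "price"])

try:
    DataFrame.pydantic_validate(bad, Product)
except ValueError as exc:
    print(exc)  # JSON-formatted schema errors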
Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+"""Tests for the integration between PySpark and Pydantic."""
+
+import pytest
+from pydantic import BaseModel, ValidationError
+from pyspark.testing.utils import assertDataFrameEqual
+import pyspark.sql.types as T
+
+import pandera.pyspark as pa
+from pandera.typing.pyspark_sql import DataFrame as PySparkSQLDataFrame
+from pandera.typing.pyspark import DataFrame as PySparkDataFrame
+from pandera.pyspark import DataFrameModel
+
+
+@pytest.fixture
+def sample_schema_model():
+    class SampleSchema(DataFrameModel):
+        """
+        Sample schema model with data checks.
+        """
+
+        product: T.StringType() = pa.Field()
+        price: T.IntegerType() = pa.Field()
+
+    return SampleSchema
+
+
+@pytest.fixture(
+    params=[PySparkDataFrame, PySparkSQLDataFrame],
+    ids=["pyspark", "pyspark_sql"],
+)
+def pydantic_container(request, sample_schema_model):
+    TypingClass = request.param
+
+    class PydanticContainer(BaseModel):
+        """
+        Pydantic container with a DataFrameModel as a field.
+        """
+
+        data: TypingClass[sample_schema_model]
+
+    return PydanticContainer
+
+
+@pytest.fixture
+def correct_data(spark, sample_data, sample_spark_schema):
+    """
+    Correct data that should pass validation.
+    """
+    return spark.createDataFrame(sample_data, sample_spark_schema)
+
+
+@pytest.fixture
+def incorrect_data(spark):
+    """
+    Incorrect data that should fail validation.
+    """
+    data = [
+        (1, "Apples"),
+        (2, "Bananas"),
+    ]
+    return spark.createDataFrame(data, ["product", "price"])
+
+
+def test_pydantic_model_instantiates_with_correct_data(
+    correct_data, pydantic_container
+):
+    """
+    Test that a Pydantic model can be instantiated with a DataFrameModel when data is valid.
+    """
+    my_container = pydantic_container(data=correct_data)
+    assertDataFrameEqual(my_container.data, correct_data)
+
+
+def test_pydantic_model_throws_validation_error_with_incorrect_data(
+    incorrect_data, pydantic_container
+):
+    """
+    Test that a Pydantic model throws a ValidationError when data is invalid.
+    """
+    with pytest.raises(ValidationError):
+        pydantic_container(data=incorrect_data)
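The `spark`, `sample_data`, and `sample_spark_schema` fixtures are not defined in this file; they come from a conftest outside this diff. A hypothetical sketch of what they might look like, inferred from SampleSchema (illustration only; everything beyond the three fixture names is invented):

# Hypothetical conftest.py (not part of this commit): plausible
# definitions for the fixtures the test module depends on.
import pytest
import pyspark.sql.types as T
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark():
    # Local SparkSession shared across the test session.
    session = SparkSession.builder.master("local[1]").getOrCreate()
    yield session
    session.stop()


@pytest.fixture
def sample_spark_schema():
    # Mirrors SampleSchema: product (string), price (integer).
    return T.StructType(
        [
            T.StructField("product", T.StringType(), False),
            T.StructField("price", T.IntegerType(), False),
        ]
    )


@pytest.fixture
def sample_data():
    return [("Apples", 5), ("Bananas", 3)]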
