Skip to content

Commit ce71af9

Browse files
committed
add test placeholder
1 parent 212a4f9 commit ce71af9

File tree

3 files changed

+93
-4
lines changed

3 files changed

+93
-4
lines changed

src/snowflake/snowpark/_internal/data_source/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,6 @@ def process(self, pickled_partition: bytearray):
755755
if isinstance(result, list):
756756
yield from result
757757
else:
758-
yield from list(reader.read(partition))
759-
break
758+
yield result
760759

761760
return UDTFIngestion

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,10 @@ def load(self, path: Optional[str] = None, _emit_ast: bool = True) -> DataFrame:
755755
return self.dbapi(**{k.lower(): v for k, v in self._cur_options.items()})
756756
if format_str == "jdbc":
757757
return self.jdbc(**{k.lower(): v for k, v in self._cur_options.items()})
758+
if format_str in self._custom_data_source_format:
759+
return self._custom_data_source(
760+
format_str, **{k.lower(): v for k, v in self._cur_options.items()}
761+
)
758762

759763
loader = getattr(self, self._format, None)
760764
if loader is not None:
@@ -2153,8 +2157,8 @@ def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame:
21532157
return dataframe
21542158

21552159
def register_custom_data_source(self, data_source: DataSource):
2156-
self._data_source_format.append(data_source.name())
2157-
self._custom_data_source_format[data_source.name()] = data_source
2160+
self._data_source_format.append(data_source.name().lower())
2161+
self._custom_data_source_format[data_source.name().lower()] = data_source
21582162

21592163
def _custom_data_source(
21602164
self,
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#
2+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from dataclasses import dataclass
6+
7+
import pytest
8+
9+
from snowflake.snowpark.types import StructType
10+
from tests.utils import RUNNING_ON_JENKINS
11+
12+
from snowflake.snowpark import InputPartition, DataSourceReader, DataSource
13+
from tests.parameters import MONGODB_CONNECTION_PARAMETERS
14+
15+
DEPENDENCIES_PACKAGE_UNAVAILABLE = True
16+
try:
17+
import pymongo # noqa: F401
18+
import pandas # noqa: F401
19+
20+
DEPENDENCIES_PACKAGE_UNAVAILABLE = False
21+
except ImportError:
22+
pass
23+
24+
25+
pytestmark = [
26+
pytest.mark.skipif(DEPENDENCIES_PACKAGE_UNAVAILABLE, reason="Missing 'pymongo'"),
27+
pytest.mark.skipif(
28+
RUNNING_ON_JENKINS, reason="cannot access external datasource from jenkins"
29+
),
30+
pytest.mark.skipif(
31+
"config.getoption('local_testing_mode', default=False)",
32+
reason="feature not available in local testing",
33+
),
34+
]
35+
36+
37+
# custom data source definition
38+
39+
40+
class MongoDbFakeDataSourceReader(DataSourceReader):
41+
def __init__(self, schema) -> None:
42+
super().__init__(schema)
43+
self.schema: StructType = schema
44+
45+
def partitions(self):
46+
return [AgeInputPartition(25), AgeInputPartition(35)]
47+
48+
def read(self, partition):
49+
from pymongo.mongo_client import MongoClient
50+
from pymongo.server_api import ServerApi
51+
52+
uri = MONGODB_CONNECTION_PARAMETERS["uri"]
53+
54+
client = MongoClient(uri, server_api=ServerApi("1"))
55+
res = []
56+
collection = client["my_test_db"]["my_collection"]
57+
document = collection.find({"age": partition.age})
58+
for doc in document:
59+
res.append((doc["name"], doc["age"], doc["city"]))
60+
61+
yield res
62+
63+
64+
@dataclass
65+
class AgeInputPartition(InputPartition):
66+
age: int
67+
68+
69+
class MongoDbFakeDataSource(DataSource):
70+
"""
71+
An example data source for batch query using the `faker` library.
72+
"""
73+
74+
@classmethod
75+
def name(cls):
76+
return "mongodb_test"
77+
78+
def schema(self):
79+
return "name string, age int, city string"
80+
81+
def reader(self, schema: StructType):
82+
return MongoDbFakeDataSourceReader(schema)
83+
84+
85+
def test_custom_mongodb_data_source(session):
86+
pass

0 commit comments

Comments
 (0)