Skip to content

Commit 9e55240

Browse files
SNOW-2288609: Support Index.get_level_values(). (#3696)
Add support for `Index.get_level_values()` so that Modin can use it to implement `DataFrame.query()`. Since we only support single-level indexes for now, as long as `level` is valid, we simply return the entire index. We raise an error that matches pandas's for invalid `level`. Signed-off-by: sfc-gh-mvashishtha <[email protected]>
1 parent c50cdeb commit 9e55240

File tree

6 files changed

+299
-17
lines changed

6 files changed

+299
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
`Series.to_snowpark()`, and `Series.to_snowflake()` on the "Pandas" and "Ray"
3535
backends. Previously, only some of these functions and methods were supported
3636
on the Pandas backend.
37+
- Added support for `Index.get_level_values()`.
3738

3839
#### Improvements
3940
- Set the default transfer limit in hybrid execution for data leaving Snowflake to 100k, which can be overridden with the SnowflakePandasTransferThreshold environment variable. This configuration is appropriate for scenarios with two available engines, "Pandas" and "Snowflake" on relational workloads.

docs/source/modin/supported/index_supported.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ Methods
190190
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
191191
| ``get_indexer_non_unique`` | N | | |
192192
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
193-
| ``get_level_values`` | N | | |
193+
| ``get_level_values`` | Y | | |
194194
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
195195
| ``get_loc`` | N | | |
196196
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+

src/snowflake/snowpark/modin/plugin/docstrings/index.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1971,7 +1971,7 @@ def get_level_values():
19711971
19721972
Parameters
19731973
----------
1974-
level : int or str
1974+
level : Any
19751975
It is either the integer position or the name of the level.
19761976
19771977
Returns
@@ -1981,7 +1981,16 @@ def get_level_values():
19811981
19821982
Notes
19831983
-----
1984-
For Index, level should be 0, since there are no multiple levels.
1984+
For Index, level should be 0, -1, or the name of the index, since there
1985+
is only one level.
1986+
1987+
Examples
1988+
--------
1989+
>>> idx = pd.Index(['a', 'b', 'c'], name='index')
1990+
>>> idx.get_level_values(0)
1991+
Index(['a', 'b', 'c'], dtype='object', name='index')
1992+
>>> idx.get_level_values('index')
1993+
Index(['a', 'b', 'c'], dtype='object', name='index')
19851994
"""
19861995

19871996
def isin():

src/snowflake/snowpark/modin/plugin/extensions/index.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -876,10 +876,21 @@ def _get_indexer_strict(self, key: Any, axis_name: str) -> tuple[Index, np.ndarr
876876
tup = self.to_pandas()._get_indexer_strict(key=key, axis_name=axis_name)
877877
return self.__constructor__(tup[0]), tup[1]
878878

879-
@index_not_implemented()
880-
def get_level_values(self, level: int | str) -> Index:
881-
WarningMessage.index_to_pandas_warning("get_level_values")
882-
return self.__constructor__(self.to_pandas().get_level_values(level=level))
879+
def get_level_values(self, level: Any) -> Index:
880+
if self.nlevels > 1:
881+
ErrorMessage.not_implemented_error(
882+
"get_level_values() is not supported for MultiIndex"
883+
) # pragma: no cover
884+
if isinstance(level, int):
885+
if level not in (0, -1):
886+
raise IndexError(
887+
f"Too many levels: Index has only 1 level, not {level + 1}"
888+
)
889+
elif not (level is self.name or level == self.name):
890+
raise KeyError(
891+
f"Requested level ({level}) does not match index name ({self.name})"
892+
)
893+
return self
883894

884895
@index_not_implemented()
885896
def isin(self) -> None:

tests/integ/modin/index/test_index_methods.py

Lines changed: 271 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pandas as native_pd
88
import pytest
9+
from pytest import param
910
from numpy.testing import assert_equal
1011
from pandas._libs import lib
1112

@@ -120,15 +121,276 @@ def test_index_intersection():
120121
assert_index_equal(diff, pd.Index([3, 4], dtype="int64"))
121122

122123

123-
@sql_count_checker(query_count=0)
124-
@pytest.mark.parametrize("native_index", NATIVE_INDEX_TEST_DATA)
125-
def test_index_get_level_values(native_index):
126-
snow_index = pd.Index(native_index)
127-
with pytest.raises(
128-
NotImplementedError,
129-
match="Snowpark pandas does not yet support the method Index.get_level_values",
130-
):
131-
assert_index_equal(snow_index.get_level_values(0), snow_index)
124+
@pytest.mark.parametrize(
125+
"native_class, snow_class",
126+
[
127+
param(native_pd.DatetimeIndex, pd.DatetimeIndex, id="DatetimeIndex"),
128+
param(native_pd.TimedeltaIndex, pd.TimedeltaIndex, id="TimedeltaIndex"),
129+
param(native_pd.Index, pd.Index, id="Index"),
130+
],
131+
)
132+
@pytest.mark.parametrize(
133+
"values",
134+
[
135+
param(tuple(), id="empty"),
136+
param([1], id="not_empty"),
137+
],
138+
)
139+
class TestGetLevelValues:
140+
@sql_count_checker(query_count=1)
141+
@pytest.mark.parametrize(
142+
"name, level",
143+
(
144+
(None, 0),
145+
(None, None),
146+
(None, -1),
147+
(0, 0),
148+
(1, 0),
149+
(1, 1.0),
150+
(1, np.int64(1)),
151+
(1, -1),
152+
(2, 0),
153+
(2, -1),
154+
("name", "name"),
155+
("name", 0),
156+
("name", -1),
157+
(pd.Timedelta(1), pd.Timedelta(1)),
158+
(pd.Timestamp("1994-07-29"), pd.Timestamp("1994-07-29")),
159+
(("name",), ("name",)),
160+
(3.5, 3.5),
161+
param(
162+
True,
163+
True,
164+
marks=pytest.mark.skip(
165+
reason="https://github.com/pandas-dev/pandas/issues/62169"
166+
),
167+
),
168+
param(
169+
False,
170+
False,
171+
marks=pytest.mark.skip(
172+
reason="https://github.com/pandas-dev/pandas/issues/62169"
173+
),
174+
),
175+
param(
176+
np.nan,
177+
np.nan,
178+
marks=pytest.mark.xfail(
179+
strict=True,
180+
raises=KeyError,
181+
reason="https://github.com/pandas-dev/pandas/issues/62169",
182+
),
183+
),
184+
param(
185+
pd.NaT,
186+
pd.NaT,
187+
marks=pytest.mark.xfail(
188+
strict=True,
189+
raises=KeyError,
190+
reason="https://github.com/pandas-dev/pandas/issues/62169",
191+
),
192+
),
193+
param(
194+
pd.NA,
195+
pd.NA,
196+
marks=pytest.mark.xfail(
197+
strict=True, raises=TypeError, reason="SNOW-2288761"
198+
),
199+
),
200+
),
201+
ids=str,
202+
)
203+
def test_valid_level(self, values, native_class, snow_class, name, level):
204+
if native_class is native_pd.TimedeltaIndex and len(values) == 0:
205+
pytest.xfail("SNOW-2288735")
206+
eval_snowpark_pandas_result(
207+
snow_class(values, name=name),
208+
native_class(values, name=name),
209+
lambda x: x.get_level_values(level),
210+
)
211+
212+
@pytest.mark.parametrize(
213+
"name, level",
214+
[
215+
(None, 1),
216+
(None, "name"),
217+
(None, pd.Timedelta(1)),
218+
(None, pd.Timestamp("1994-07-29")),
219+
param(
220+
None,
221+
True,
222+
marks=pytest.mark.skip(
223+
reason="https://github.com/pandas-dev/pandas/issues/62169"
224+
),
225+
),
226+
param(
227+
None,
228+
False,
229+
marks=pytest.mark.skip(
230+
reason="https://github.com/pandas-dev/pandas/issues/62169"
231+
),
232+
),
233+
(None, np.nan),
234+
(None, pd.NA),
235+
(1, 1),
236+
(1, "name"),
237+
(1, pd.Timedelta(1)),
238+
(1, pd.Timestamp("1994-07-29")),
239+
param(
240+
1,
241+
True,
242+
marks=pytest.mark.skip(
243+
reason="https://github.com/pandas-dev/pandas/issues/62169"
244+
),
245+
),
246+
param(
247+
1,
248+
False,
249+
marks=pytest.mark.skip(
250+
reason="https://github.com/pandas-dev/pandas/issues/62169"
251+
),
252+
),
253+
(1, None),
254+
(1, np.nan),
255+
(1, pd.NA),
256+
("name", "other_name"),
257+
("name", 1),
258+
("name", pd.Timedelta(1)),
259+
("name", None),
260+
param(
261+
"name",
262+
True,
263+
marks=pytest.mark.skip(
264+
reason="https://github.com/pandas-dev/pandas/issues/62169"
265+
),
266+
),
267+
param(
268+
"name",
269+
False,
270+
marks=pytest.mark.skip(
271+
reason="https://github.com/pandas-dev/pandas/issues/62169"
272+
),
273+
),
274+
("name", None),
275+
("name", np.nan),
276+
("name", pd.NA),
277+
(pd.Timedelta(1), 1),
278+
(pd.Timedelta(1), 1.0),
279+
(pd.Timedelta(1), "name"),
280+
(pd.Timedelta(1), pd.Timedelta(2)),
281+
(pd.Timedelta(1), pd.Timestamp("1994-07-29")),
282+
param(
283+
pd.Timedelta(1),
284+
True,
285+
marks=pytest.mark.skip(
286+
reason="https://github.com/pandas-dev/pandas/issues/62169"
287+
),
288+
),
289+
param(
290+
pd.Timedelta(1),
291+
False,
292+
marks=pytest.mark.skip(
293+
reason="https://github.com/pandas-dev/pandas/issues/62169"
294+
),
295+
),
296+
(pd.Timedelta(1), None),
297+
(pd.Timedelta(1), np.nan),
298+
(pd.Timedelta(1), pd.NA),
299+
(pd.Timestamp("1994-07-29"), 1),
300+
(pd.Timestamp("1994-07-29"), 1.0),
301+
(pd.Timestamp("1994-07-29"), "name"),
302+
(pd.Timestamp("1994-07-29"), pd.Timedelta(1)),
303+
param(
304+
pd.Timestamp("1994-07-29"),
305+
True,
306+
marks=pytest.mark.skip(
307+
reason="https://github.com/pandas-dev/pandas/issues/62169"
308+
),
309+
),
310+
param(
311+
pd.Timestamp("1994-07-29"),
312+
False,
313+
marks=pytest.mark.skip(
314+
reason="https://github.com/pandas-dev/pandas/issues/62169"
315+
),
316+
),
317+
(pd.Timestamp("1994-07-29"), np.nan),
318+
(pd.Timestamp("1994-07-29"), pd.NA),
319+
*(
320+
# pandas may or may not raise an error for these cases due to
321+
# https://github.com/pandas-dev/pandas/issues/62169
322+
param(
323+
*values,
324+
marks=pytest.mark.skip(
325+
reason="https://github.com/pandas-dev/pandas/issues/62169"
326+
),
327+
)
328+
for values in (
329+
(True, 1),
330+
(True, 1.0),
331+
(True, "name"),
332+
(True, pd.Timedelta(1)),
333+
(True, pd.Timestamp("1994-07-29")),
334+
(True, False),
335+
(True, None),
336+
(True, np.nan),
337+
(True, pd.NA),
338+
(False, 1),
339+
(False, 1.0),
340+
(False, "name"),
341+
(False, pd.Timedelta(1)),
342+
(False, pd.Timestamp("1994-07-29")),
343+
(False, True),
344+
(False, None),
345+
(False, np.nan),
346+
(False, pd.NA),
347+
(np.nan, 1),
348+
)
349+
),
350+
(np.nan, "name"),
351+
(np.nan, pd.Timedelta(1)),
352+
(np.nan, pd.Timestamp("1994-07-29")),
353+
(np.nan, True),
354+
param(
355+
np.nan,
356+
False,
357+
marks=pytest.mark.skip(
358+
reason="https://github.com/pandas-dev/pandas/issues/62169"
359+
),
360+
),
361+
(np.nan, None),
362+
(np.nan, pd.NA),
363+
*(
364+
param(
365+
*values,
366+
marks=pytest.mark.xfail(
367+
strict=True,
368+
raises=TypeError,
369+
reason="SNOW-2288761",
370+
),
371+
)
372+
for values in (
373+
(pd.NA, 1),
374+
(pd.NA, 1.0),
375+
(pd.NA, "name"),
376+
(pd.NA, pd.Timedelta(1)),
377+
(pd.NA, pd.Timestamp("1994-07-29")),
378+
(pd.NA, True),
379+
(pd.NA, False),
380+
(pd.NA, None),
381+
(pd.NA, np.nan),
382+
)
383+
),
384+
],
385+
)
386+
@sql_count_checker(query_count=0)
387+
def test_invalid_level(self, values, native_class, snow_class, name, level):
388+
eval_snowpark_pandas_result(
389+
snow_class(values, name=name),
390+
native_class(values, name=name),
391+
lambda index: index.get_level_values(level),
392+
expect_exception=True,
393+
)
132394

133395

134396
@sql_count_checker(query_count=0)

tests/integ/modin/test_unimplemented.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ def test_unsupported_str_methods(func, func_name, caplog) -> None:
155155
lambda idx: idx.union(),
156156
lambda idx: idx.difference(),
157157
lambda idx: idx.get_indexer_for(),
158-
lambda idx: idx.get_level_values(),
159158
lambda idx: idx.slice_indexer(),
160159
lambda idx: idx.nbytes(),
161160
lambda idx: idx.memory_usage(),

0 commit comments

Comments
 (0)