Commits (23)
1bd4269
[Data] Add map namespace support for expression operations
ryankert01 Jan 6, 2026
2e157bd
Merge branch 'master' into map-expression
ryankert01 Jan 6, 2026
68bef64
address ai review
ryankert01 Jan 6, 2026
fe2642b
fix cursor bot suggestions
ryankert01 Jan 6, 2026
843cac1
Merge branch 'master' into map-expression
ryankert01 Jan 6, 2026
f16bfd1
Merge remote-tracking branch 'origin/master' into map-expression
ryankert01 Jan 12, 2026
df1fe8c
refactor tests
ryankert01 Jan 12, 2026
6461062
Merge branch 'master' into map-expression
ryankert01 Jan 12, 2026
fcd3652
Merge branch 'master' into map-expression
ryankert01 Jan 18, 2026
202a652
Merge branch 'master' into map-expression
owenowenisme Jan 21, 2026
50a2e64
Update python/ray/data/namespace_expressions/map_namespace.py
ryankert01 Jan 22, 2026
e613cfa
address commits
ryankert01 Jan 22, 2026
70a3760
Merge branch 'master' into map-expression
ryankert01 Jan 22, 2026
49268ec
Merge branch 'master' into map-expression
ryankert01 Jan 25, 2026
c390a24
create 3 helper functions to make the intent clearer
ryankert01 Jan 25, 2026
5e024c8
use numpy.repeat()
ryankert01 Jan 25, 2026
10e4b7c
text extraction on empty ChunkedArray
ryankert01 Jan 25, 2026
f9d53b8
Merge branch 'master' into map-expression
ryankert01 Jan 25, 2026
978132e
lint
ryankert01 Jan 25, 2026
2eff519
Merge remote-tracking branch 'origin/map-expression' into map-expression
ryankert01 Jan 25, 2026
7a11478
Merge branch 'master' into map-expression
goutamvenkat-anyscale Feb 4, 2026
dae4645
Merge branch 'master' into map-expression
ryankert01 Feb 8, 2026
59f8047
address comments
ryankert01 Feb 8, 2026
13 changes: 13 additions & 0 deletions python/ray/data/expressions.py
@@ -28,6 +28,7 @@
if TYPE_CHECKING:
    from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
    from ray.data.namespace_expressions.list_namespace import _ListNamespace
    from ray.data.namespace_expressions.map_namespace import _MapNamespace
    from ray.data.namespace_expressions.string_namespace import _StringNamespace
    from ray.data.namespace_expressions.struct_namespace import _StructNamespace

@@ -646,6 +647,13 @@ def struct(self) -> "_StructNamespace":

        return _StructNamespace(self)

    @property
    def map(self) -> "_MapNamespace":
        """Access map/dict operations for this expression."""
        from ray.data.namespace_expressions.map_namespace import _MapNamespace

        return _MapNamespace(self)

    @property
    def dt(self) -> "_DatetimeNamespace":
        """Access datetime operations for this expression."""
@@ -1493,6 +1501,7 @@ def download(uri_column_name: str) -> DownloadExpr:
    "_ListNamespace",
    "_StringNamespace",
    "_StructNamespace",
    "_MapNamespace",
    "_DatetimeNamespace",
]

@@ -1511,6 +1520,10 @@ def __getattr__(name: str):
        from ray.data.namespace_expressions.struct_namespace import _StructNamespace

        return _StructNamespace
    elif name == "_MapNamespace":
        from ray.data.namespace_expressions.map_namespace import _MapNamespace

        return _MapNamespace
    elif name == "_DatetimeNamespace":
        from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace

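For context, a minimal usage sketch of the new .map namespace wired up above. The "attrs" column and its contents are illustrative and mirror the fixture used in the new tests; this is a sketch, not part of the diff.

import pyarrow as pa

import ray
from ray.data.expressions import col

# Build a small dataset with a map-typed column named "attrs" (illustrative data).
table = pa.table(
    {
        "attrs": pa.array(
            [{"color": "red", "size": "M"}, {"brand": "Ray"}],
            type=pa.map_(pa.string(), pa.string()),
        )
    }
)
ds = ray.data.from_arrow(table)

# Project keys and values out of the map column as list columns.
ds = ds.with_column("keys", col("attrs").map.keys())      # [["color", "size"], ["brand"]]
ds = ds.with_column("values", col("attrs").map.values())  # [["red", "M"], ["Ray"]]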
165 changes: 165 additions & 0 deletions python/ray/data/namespace_expressions/map_namespace.py
@@ -0,0 +1,165 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING

import pyarrow
import pyarrow.compute as pc

from ray.data.datatype import DataType
from ray.data.expressions import pyarrow_udf

if TYPE_CHECKING:
    from ray.data.expressions import Expr, UDFExpr


class MapComponent(str, Enum):
    KEYS = "keys"
    VALUES = "values"


def _extract_map_component(
@goutamvenkat-anyscale (Contributor) commented on Jan 23, 2026:

Let's create 3 helper functions to make the intent clearer.

  1. _get_child_array which gets keys and values
  2. _make_empty_list_array
  3. _rebuild_list_array with normalized offsets

For each one add an example of what it's doing in the comments.

def _extract_map_component(
    arr: pyarrow.Array, component: MapComponent
) -> pyarrow.Array:
    """Extract keys or values from a MapArray or ListArray<Struct>."""
    
    if isinstance(arr, pyarrow.ChunkedArray):
        return pyarrow.chunked_array(
            [_extract_map_component(chunk, component) for chunk in arr.chunks]
        )

    child_array = _get_child_array(arr, component)
    
    if child_array is None:
        return _make_empty_list_array(arr, component)
    
    return _rebuild_list_array(arr, child_array)
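For reference, here is a rough sketch of what those three helpers could look like, assembled from the logic already present in this diff. The names follow the review suggestion above; the exact signatures and docstrings are assumptions, not the merged implementation.

def _get_child_array(arr: pyarrow.Array, component: MapComponent):
    """Return the flat child array holding the keys or values, or None.

    Example: a MapArray [{"a": 1}, {"b": 2}] yields ["a", "b"] for KEYS
    and [1, 2] for VALUES.
    """
    if isinstance(arr, pyarrow.MapArray):
        return arr.keys if component == MapComponent.KEYS else arr.items
    if isinstance(arr, (pyarrow.ListArray, pyarrow.LargeListArray)):
        flat_values = arr.values
        if (
            isinstance(flat_values, pyarrow.StructArray)
            and flat_values.type.num_fields >= 2
        ):
            return flat_values.field(0 if component == MapComponent.KEYS else 1)
    return None


def _make_empty_list_array(arr: pyarrow.Array, component: MapComponent) -> pyarrow.ListArray:
    """Build an all-null ListArray for empty or all-null non-map inputs.

    Example: an all-null input of length 2 yields a ListArray [null, null].
    """
    if len(arr) > 0 and arr.null_count < len(arr):
        raise TypeError(
            f".map.{component.value}() can only be called on MapArray or "
            f"List<Struct<key, value>> types, but got {arr.type}."
        )
    return pyarrow.ListArray.from_arrays(
        offsets=[0] * (len(arr) + 1),
        values=pyarrow.array([], type=pyarrow.null()),
        mask=pyarrow.array([True] * len(arr)),
    )


def _rebuild_list_array(arr: pyarrow.Array, child_array: pyarrow.Array) -> pyarrow.ListArray:
    """Re-wrap the child array with the parent's offsets, normalized to start at 0.

    Example: a sliced array with offsets [7, 8, 9, 10] is rebuilt with offsets
    [0, 1, 2, 3] over the matching slice of the child array.
    """
    offsets = arr.offsets
    if len(offsets) > 0 and offsets[0].as_py() != 0:
        start_offset, end_offset = offsets[0], offsets[-1]
        child_array = child_array.slice(
            offset=start_offset.as_py(),
            length=(end_offset - start_offset).as_py(),
        )
        offsets = pc.subtract(offsets, start_offset)
    return pyarrow.ListArray.from_arrays(
        offsets=offsets, values=child_array, mask=arr.is_null()
    )

Splitting the function this way keeps the ChunkedArray recursion, the type dispatch, the error/empty path, and the offset normalization independently readable and testable.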

    arr: pyarrow.Array, component: MapComponent
) -> pyarrow.Array:
    """
    Extracts keys or values from a MapArray or ListArray<Struct>.

    This serves as the primary implementation since PyArrow does not yet
    expose dedicated compute kernels for map projection in the Python API.
    """
    # 1. Handle Chunked Arrays (Recursion)
    if isinstance(arr, pyarrow.ChunkedArray):
        return pyarrow.chunked_array(
            [_extract_map_component(chunk, component) for chunk in arr.chunks]
        )

    child_array = None

    # Case 1: MapArray
    if isinstance(arr, pyarrow.MapArray):
        if component == MapComponent.KEYS:
            child_array = arr.keys
        else:
            child_array = arr.items

    # Case 2: ListArray<Struct<Key, Value>>
    elif isinstance(arr, (pyarrow.ListArray, pyarrow.LargeListArray)):
        flat_values = arr.values
        if (
            isinstance(flat_values, pyarrow.StructArray)
            and flat_values.type.num_fields >= 2
        ):
            idx = 0 if component == MapComponent.KEYS else 1
            child_array = flat_values.field(idx)

    if child_array is None:
        # This can happen if the input array is not a supported map type.
        # We allow this to proceed only if the array is empty or all-nulls,
        # in which case we'll produce an empty or all-nulls output.
        if len(arr) > 0 and arr.null_count < len(arr):
            raise TypeError(
                f"Expression is not a map type. .map.{component.value}() can only be "
                f"called on MapArray or List<Struct<key, value>> types, but got {arr.type}."
            )
        return pyarrow.ListArray.from_arrays(
            offsets=[0] * (len(arr) + 1),
            values=pyarrow.array([], type=pyarrow.null()),
            mask=pyarrow.array([True] * len(arr)),
        )

    # Reconstruct ListArray & Normalize Offsets
    offsets = arr.offsets
    if len(offsets) > 0:  # Handle offsets changes
        start_offset = offsets[0]
        if start_offset.as_py() != 0:
            # Slice child_array to match normalized offsets
            end_offset = offsets[-1]
            child_array = child_array.slice(
                offset=start_offset.as_py(), length=(end_offset - start_offset).as_py()
            )
            offsets = pc.subtract(offsets, start_offset)

    return pyarrow.ListArray.from_arrays(
        offsets=offsets, values=child_array, mask=arr.is_null()
    )


@dataclass
class _MapNamespace:
    """Namespace for map operations on expression columns.

    This namespace provides methods for operating on map-typed columns
    (including MapArrays and ListArrays of Structs) using PyArrow UDFs.

    Example:
        >>> from ray.data.expressions import col
        >>> # Get keys from map column
        >>> expr = col("headers").map.keys()
        >>> # Get values from map column
        >>> expr = col("headers").map.values()
    """

    _expr: "Expr"

    def keys(self) -> "UDFExpr":
        """Returns a list expression containing the keys of the map.

        Example:
            >>> from ray.data.expressions import col
            >>> # Get keys from map column
            >>> expr = col("headers").map.keys()

        Returns:
            A list expression containing the keys.
        """
        return self._create_projection_udf(MapComponent.KEYS)

    def values(self) -> "UDFExpr":
        """Returns a list expression containing the values of the map.

        Example:
            >>> from ray.data.expressions import col
            >>> # Get values from map column
            >>> expr = col("headers").map.values()

        Returns:
            A list expression containing the values.
        """
        return self._create_projection_udf(MapComponent.VALUES)

    def _create_projection_udf(self, component: MapComponent) -> "UDFExpr":
        """Helper to generate UDFs for map projections."""

        return_dtype = DataType(object)
        if self._expr.data_type.is_arrow_type():
            arrow_type = self._expr.data_type.to_arrow_dtype()

            is_physical_map = (
                (
                    pyarrow.types.is_list(arrow_type)
                    or pyarrow.types.is_large_list(arrow_type)
                )
                and pyarrow.types.is_struct(arrow_type.value_type)
                and arrow_type.value_type.num_fields >= 2
            )

            inner_arrow_type = None
            if pyarrow.types.is_map(arrow_type):
                inner_arrow_type = (
                    arrow_type.key_type
                    if component == MapComponent.KEYS
                    else arrow_type.item_type
                )
            elif is_physical_map:
                idx = 0 if component == MapComponent.KEYS else 1
                inner_arrow_type = arrow_type.value_type.field(idx).type

            if inner_arrow_type:
                return_dtype = DataType.list(DataType.from_arrow(inner_arrow_type))

        @pyarrow_udf(return_dtype=return_dtype)
        def _project_map(arr: pyarrow.Array) -> pyarrow.Array:
            return _extract_map_component(arr, component)

        return _project_map(self._expr)
128 changes: 128 additions & 0 deletions python/ray/data/tests/expressions/test_namespace_map.py
@@ -0,0 +1,128 @@
import pandas as pd
import pyarrow as pa
import pytest
from packaging import version

import ray
from ray.data._internal.util import rows_same
from ray.data.expressions import col

pytestmark = pytest.mark.skipif(
    version.parse(pa.__version__) < version.parse("19.0.0"),
    reason="Namespace expressions tests require PyArrow >= 19.0",
)


@pytest.fixture
def map_dataset():
    """Fixture that creates a dataset backed by an Arrow MapArray column."""
    map_items = [
        {"attrs": {"color": "red", "size": "M"}},
        {"attrs": {"brand": "Ray"}},
    ]
    map_type = pa.map_(pa.string(), pa.string())
    arrow_table = pa.table(
        {"attrs": pa.array([row["attrs"] for row in map_items], type=map_type)}
    )
    return ray.data.from_arrow(arrow_table)


def _assert_result(result_df: pd.DataFrame, expected_df: pd.DataFrame, drop_cols: list):
    """Helper to drop columns and assert equality."""
    result_df = result_df.drop(columns=drop_cols)
    assert rows_same(result_df, expected_df)


class TestMapNamespace:
    """Tests for map namespace operations using the shared map_dataset fixture."""

    def test_map_keys(self, map_dataset):
        result = map_dataset.with_column("keys", col("attrs").map.keys()).to_pandas()
        expected = pd.DataFrame({"keys": [["color", "size"], ["brand"]]})
        _assert_result(result, expected, drop_cols=["attrs"])

    def test_map_values(self, map_dataset):
        result = map_dataset.with_column(
            "values", col("attrs").map.values()
        ).to_pandas()
        expected = pd.DataFrame({"values": [["red", "M"], ["Ray"]]})
        _assert_result(result, expected, drop_cols=["attrs"])

    def test_map_chaining(self, map_dataset):
        # map.keys() returns a list, so .list.len() should apply
        result = map_dataset.with_column(
            "num_keys", col("attrs").map.keys().list.len()
        ).to_pandas()
        expected = pd.DataFrame({"num_keys": [2, 1]})
        _assert_result(result, expected, drop_cols=["attrs"])


def test_physical_map_extraction():
    """Test extraction works on List<Struct> (Physical Maps)."""
    # Construct List<Struct<k, v>>
    struct_type = pa.struct([pa.field("k", pa.string()), pa.field("v", pa.int64())])
    list_type = pa.list_(struct_type)

    data_py = [[{"k": "a", "v": 1}], [{"k": "b", "v": 2}]]
    arrow_table = pa.Table.from_arrays(
        [pa.array(data_py, type=list_type)], names=["data"]
    )
    ds = ray.data.from_arrow(arrow_table)

    result = (
        ds.with_column("keys", col("data").map.keys())
        .with_column("values", col("data").map.values())
        .to_pandas()
    )

    expected = pd.DataFrame(
        {
            "data": data_py,
            "keys": [["a"], ["b"]],
            "values": [[1], [2]],
        }
    )
    assert rows_same(result, expected)


def test_map_sliced_offsets():
    """Test extraction works correctly on sliced Arrow arrays (offset > 0)."""
    items = [{"m": {"id": i}} for i in range(10)]
    map_type = pa.map_(pa.string(), pa.int64())
    arrays = pa.array([row["m"] for row in items], type=map_type)
    table = pa.Table.from_arrays([arrays], names=["m"])

    # Force offsets by slicing the table before ingestion
    sliced_table = table.slice(offset=7, length=3)
    ds = ray.data.from_arrow(sliced_table)

    result = ds.with_column("vals", col("m").map.values()).to_pandas()
    expected = pd.DataFrame({"vals": [[7], [8], [9]]})
    _assert_result(result, expected, drop_cols=["m"])


def test_map_nulls_and_empty():
    """Test handling of null maps and empty maps."""
    items_data = [{"m": {"a": 1}}, {"m": {}}, {"m": None}]

    map_type = pa.map_(pa.string(), pa.int64())
    arrays = pa.array([row["m"] for row in items_data], type=map_type)
    arrow_table = pa.Table.from_arrays([arrays], names=["m"])
    ds = ray.data.from_arrow(arrow_table)

    # Use take_all() to avoid pandas casting errors with mixed None/list types
    rows = (
        ds.with_column("keys", col("m").map.keys())
        .with_column("values", col("m").map.values())
        .take_all()
    )

    assert list(rows[0]["keys"]) == ["a"] and list(rows[0]["values"]) == [1]
    assert len(rows[1]["keys"]) == 0 and len(rows[1]["values"]) == 0
    assert rows[2]["keys"] is None and rows[2]["values"] is None
Comment on lines +119 to +121
A contributor commented:

Let's use rows_same

@ryankert01 (Member, Author) replied on Feb 8, 2026:

rows_same operates on pandas DataFrames, which can't handle the mixed None/list column during conversion. The to_pandas() path triggers TensorArray casting, which fails on the mixed types. Let's keep it!

There is a workaround, but it's too complex for the context of this test:

    ctx = ray.data.context.DataContext.get_current()
    ctx.enable_tensor_extension_casting = False
    try:
        result = (
            ds.with_column("keys", col("m").map.keys())
            .with_column("values", col("m").map.values())
            .to_pandas()
        )
        expected = pd.DataFrame(
            {
                "keys": [["a"], [], None],
                "values": [[1], [], None],
            }
        )
        _assert_result(result, expected, drop_cols=["m"])
    finally:
        ctx.enable_tensor_extension_casting = True



if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))