200 changes: 200 additions & 0 deletions python/ray/data/_internal/datasource/kafka_datasink.py
@@ -0,0 +1,200 @@
"""Kafka datasink

This module provides a Kafka datasink implementation for Ray Data.

Requires:
- kafka-python: https://kafka-python.readthedocs.io/
"""

import json
from collections.abc import Callable, Iterable, Mapping
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from kafka import KafkaProducer
    from kafka.errors import KafkaError, KafkaTimeoutError

Missing runtime imports for Kafka classes

High Severity

KafkaProducer, KafkaError, and KafkaTimeoutError are imported only under TYPE_CHECKING, meaning they won't exist at runtime. When the write method executes, it will fail with a NameError because these names are undefined. The existing kafka_datasource.py correctly handles this by importing inside the function that uses them (e.g., from kafka import KafkaConsumer on line 336 of that file).

Member: This is valid.
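
A minimal sketch of the lazy-import pattern the bot refers to, with the imports moved inside KafkaDatasink.write (exact placement is a judgment call; this is not the PR's code):

def write(self, blocks, ctx):
    # Import here so the module can be imported without kafka-python
    # installed; the TYPE_CHECKING imports only cover type annotations.
    from kafka import KafkaProducer
    from kafka.errors import KafkaError, KafkaTimeoutError

    producer = KafkaProducer(
        bootstrap_servers=self.bootstrap_servers,
        **self.producer_config,
    )
    ...  # rest of the write logic unchanged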


from ray.data import Datasink
Member: This will give a circular import error.

Suggested change
from ray.data import Datasink
from ray.data.datasource.datasink import Datasink

from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, BlockAccessor


class KafkaDatasink(Datasink):
    """
    Ray Data sink for writing to Apache Kafka topics using kafka-python.

    Writes blocks of data to Kafka with configurable serialization
    and producer settings.
    """

    def __init__(
        self,
        topic: str,
        bootstrap_servers: str,
        key_field: str | None = None,
Member: Use Optional.
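
A sketch of what the reviewer is asking for, assuming the `X | None` syntax is to be avoided because Ray still supports Python versions older than 3.10 (the signature below is illustrative, not the final diff):

from typing import Any, Callable, Dict, Optional

def __init__(
    self,
    topic: str,
    bootstrap_servers: str,
    key_field: Optional[str] = None,
    key_serializer: str = "string",
    value_serializer: str = "json",
    producer_config: Optional[Dict[str, Any]] = None,
    delivery_callback: Optional[Callable] = None,
):
    ...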

key_serializer: str = "string",
value_serializer: str = "json",
producer_config: dict[str, Any] | None = None,
delivery_callback: Callable | None = None,
):
"""
Initialize Kafka sink.

Args:
topic: Kafka topic name
bootstrap_servers: Comma-separated Kafka broker addresses (e.g., 'localhost:9092')
key_field: Optional field name to use as message key
key_serializer: Key serialization format ('json', 'string', or 'bytes')
value_serializer: Value serialization format ('json', 'string', or 'bytes')
producer_config: Additional Kafka producer configuration (kafka-python format)
delivery_callback: Optional callback for delivery reports (called with metadata or exception)
"""
VALID_SERIALIZERS = {"json", "string", "bytes"}
Member: Add _check_import here.
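
A sketch of the suggested dependency check, assuming Ray's internal _check_import helper (used by other Ray Data datasources) keeps its module=/package= keyword signature:

from ray.data._internal.util import _check_import
from ray.data.datasource.datasink import Datasink


class KafkaDatasink(Datasink):
    def __init__(self, topic: str, bootstrap_servers: str, **kwargs):
        # Fail fast with a clear "please install kafka-python" error
        # instead of a NameError/ImportError later inside write().
        _check_import(self, module="kafka", package="kafka-python")
        self.topic = topic
        self.bootstrap_servers = bootstrap_servers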

        if key_serializer not in VALID_SERIALIZERS:
            raise ValueError(
                f"key_serializer must be one of {VALID_SERIALIZERS}, "
                f"got '{key_serializer}'"
            )
        if value_serializer not in VALID_SERIALIZERS:
            raise ValueError(
                f"value_serializer must be one of {VALID_SERIALIZERS}, "
                f"got '{value_serializer}'"
            )

        self.topic = topic
        self.bootstrap_servers = bootstrap_servers
        self.key_field = key_field
        self.key_serializer = key_serializer
        self.value_serializer = value_serializer
        self.producer_config = producer_config or {}
        self.delivery_callback = delivery_callback

    def _row_to_dict(self, row: Any) -> Any:
        """Convert row to dict if possible, otherwise return as-is."""
        # 1. Fast path for standard dicts
        if isinstance(row, dict):
            return row

        # 2. Ray's ArrowRow/PandasRow (and other Mappings)
        # They usually implement .as_pydict() for efficient conversion
        if hasattr(row, "as_pydict"):
            return row.as_pydict()

        # 3. Standard NamedTuple (no __dict__, but has _asdict)
        if hasattr(row, "_asdict"):
            return row._asdict()

        # 4. General Mapping fallback (e.g. other dict-likes)
        if isinstance(row, Mapping):
            return dict(row)

        # 5. Fallback: return as-is (e.g. primitives, strings, bytes)
        return row

    def _serialize_value(self, value: Any) -> bytes:
        """Serialize value based on configured format."""
        # Convert ArrowRow to dict first
        value = self._row_to_dict(value)

        if self.value_serializer == "json":
            return json.dumps(value).encode("utf-8")
        elif self.value_serializer == "string":
            return str(value).encode("utf-8")
        else:  # bytes
            return value if isinstance(value, bytes) else str(value).encode("utf-8")

    def _serialize_key(self, key: Any) -> bytes:
        """Serialize key based on configured format."""
        if self.key_serializer == "json":
            return json.dumps(key).encode("utf-8")
        elif self.key_serializer == "string":
            return str(key).encode("utf-8")
        else:  # bytes
            return key if isinstance(key, bytes) else str(key).encode("utf-8")

    def _extract_key(self, row: Any) -> bytes | None:
        """Extract and encode message key from row."""
        # Convert ArrowRow to dict first
        row_dict = self._row_to_dict(row)

        key = None
        if self.key_field and isinstance(row_dict, dict):
            key_value = row_dict.get(self.key_field)
            if key_value is not None:
                key = self._serialize_key(key_value)
        return key

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> Any:
        """
        Write blocks of data to Kafka.

        Args:
            blocks: Iterable of Ray data blocks
            ctx: Ray data context

        Returns:
            Write statistics (total records written)
        """
        # Create producer with config
        producer = KafkaProducer(
            bootstrap_servers=self.bootstrap_servers,
            **self.producer_config,
        )
        total_records = 0
        failed_messages = 0
        futures = []

        try:
            for block in blocks:
                block_accessor = BlockAccessor.for_block(block)

                # Iterate through rows in block
                for row in block_accessor.iter_rows(public_row_format=False):
                    # Extract key if specified
                    key = self._extract_key(row)

                    # Serialize value
                    value = self._serialize_value(row)

Redundant row-to-dict conversion per row

Low Severity

_row_to_dict is called twice for each row during processing—once in _extract_key and once in _serialize_value. The PR discussion explicitly noted this redundancy and the author marked it as "addressed", but the duplicate conversion remains. The row could be converted once and passed to both methods.


Member: This is valid.
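
A sketch of the single-conversion refactor under discussion; _encode_row is a hypothetical helper, not part of the PR:

from typing import Any, Optional, Tuple


def _encode_row(sink: "KafkaDatasink", row: Any) -> Tuple[Optional[bytes], bytes]:
    """Convert the row to a dict once, then derive both key and value from it."""
    row_dict = sink._row_to_dict(row)

    key = None
    if sink.key_field and isinstance(row_dict, dict):
        key_value = row_dict.get(sink.key_field)
        if key_value is not None:
            key = sink._serialize_key(key_value)

    # Passing the pre-converted dict means _serialize_value's own
    # _row_to_dict call only hits the cheap isinstance(row, dict) fast path.
    value = sink._serialize_value(row_dict)
    return key, value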


                    # Produce to Kafka
                    try:
                        future = producer.send(self.topic, value=value, key=key)
                    except BufferError:
                        # Queue is full, flush and retry
                        producer.flush(timeout=10.0)
                        future = producer.send(self.topic, value=value, key=key)

                    # Add callback if provided
                    if self.delivery_callback:
                        future.add_callback(
                            lambda m: self.delivery_callback(metadata=m)
                        )
                        future.add_errback(
                            lambda e: self.delivery_callback(exception=e)
                        )
                    futures.append(future)
Member: Do you think we should flush the futures buffer when there are N items in it? If we have millions of rows, this could accumulate.
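
A sketch of the bounded-buffer flush being proposed; MAX_PENDING is a hypothetical threshold, and the snippet is meant to sit inside the row loop of write():

MAX_PENDING = 10_000  # hypothetical; tune for memory vs. throughput

futures.append(future)
total_records += 1

# Periodically flush and drain the futures list so it does not grow
# without bound when writing millions of rows.
if len(futures) >= MAX_PENDING:
    producer.flush(timeout=30.0)
    for f in futures:
        try:
            f.get(timeout=0)
        except Exception:
            failed_messages += 1
    futures.clear()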

                    total_records += 1

            # Final flush to ensure all messages are sent
            producer.flush(timeout=30.0)

            # Check for any failed futures
            for future in futures:
                try:
                    future.get(timeout=0)  # Non-blocking check since we already flushed
                except Exception:
                    failed_messages += 1

Failed messages silently counted instead of raising exception

Medium Severity

When message delivery fails, the exception is caught and the failure is silently counted in failed_messages. Unlike other datasinks (e.g., BigQuery which raises RuntimeError on write failure), no exception is raised. Since write_kafka returns None, users have no way to know messages failed. This inconsistency with other datasinks could cause silent data loss.

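One way to surface delivery failures instead of silently counting them, a sketch only (the error message is illustrative; the behavior mirrors the raise-on-failure approach the comment attributes to datasinks like BigQuery):

# After the final flush in write():
errors = []
for future in futures:
    try:
        future.get(timeout=0)
    except Exception as e:  # kafka-python surfaces KafkaError subclasses here
        errors.append(e)

if errors:
    raise RuntimeError(
        f"Failed to deliver {len(errors)} of {total_records} messages to "
        f"Kafka topic '{self.topic}'; first error: {errors[0]}"
    )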


        except KafkaTimeoutError as e:
            raise RuntimeError(f"Failed to write to Kafka: {e}") from e
        except KafkaError as e:
            raise RuntimeError(f"Failed to write to Kafka: {e}") from e
        finally:
            # Close the producer
            producer.close(timeout=5.0)

        return {"total_records": total_records, "failed_messages": failed_messages}
Member: We should log these.
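
A sketch of the logging the reviewer is asking for, assuming a standard module-level logger (logger name and wording are illustrative):

import logging

logger = logging.getLogger(__name__)

# At the end of write(), report the outcome instead of only returning counts.
if failed_messages:
    logger.warning(
        "Kafka write finished with %d failed messages out of %d records for topic %s.",
        failed_messages,
        total_records,
        self.topic,
    )
else:
    logger.info("Wrote %d records to Kafka topic %s.", total_records, self.topic)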

39 changes: 39 additions & 0 deletions python/ray/data/dataset.py
@@ -40,6 +40,7 @@
from ray.data._internal.datasource.iceberg_datasink import IcebergDatasink
from ray.data._internal.datasource.image_datasink import ImageDatasink
from ray.data._internal.datasource.json_datasink import JSONDatasink
from ray.data._internal.datasource.kafka_datasink import KafkaDatasink
from ray.data._internal.datasource.lance_datasink import LanceDatasink
from ray.data._internal.datasource.mongo_datasink import MongoDatasink
from ray.data._internal.datasource.numpy_datasink import NumpyDatasink
@@ -5312,6 +5313,44 @@ def write_lance(
            concurrency=concurrency,
        )

    @ConsumptionAPI
    def write_kafka(
        self,
        topic: str,
        bootstrap_servers: str,
        key_field: str | None = None,
Member: Use Optional.

key_serializer: str = "string",
value_serializer: str = "json",
producer_config: dict[str, Any] | None = None,
delivery_callback: Callable | None = None,
Member: We should also accept and pass concurrency & ray_remote_args.
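
A sketch of the requested pass-through (a fragment of the Dataset method), assuming write_datasink accepts ray_remote_args and concurrency keyword arguments the way other write_* methods forward them; the other Kafka options are elided for brevity:

    def write_kafka(
        self,
        topic: str,
        bootstrap_servers: str,
        ray_remote_args: Optional[Dict[str, Any]] = None,
        concurrency: Optional[int] = None,
    ) -> None:
        sink = KafkaDatasink(topic=topic, bootstrap_servers=bootstrap_servers)
        self.write_datasink(
            sink,
            ray_remote_args=ray_remote_args,
            concurrency=concurrency,
        )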

    ) -> None:
        """
        Convenience method to write Ray Dataset to Kafka.

        Example:
Member: ConsumptionAPI matches on the word "Examples" in order to insert its docs.

Suggested change
Example:
Examples:

            >>> ds = ray.data.range(100)
            >>> ds.write_kafka("my-topic", "localhost:9092")

        Args:
            topic: Kafka topic name
            bootstrap_servers: Comma-separated Kafka broker addresses
            key_field: Optional field name to use as message key
            key_serializer: Key serialization format ('json', 'string', or 'bytes')
            value_serializer: Value serialization format ('json', 'string', or 'bytes')
            producer_config: Additional Kafka producer configuration (kafka-python format)
            delivery_callback: Optional callback for delivery reports
        """
        sink = KafkaDatasink(
            topic=topic,
            bootstrap_servers=bootstrap_servers,
            key_field=key_field,
            key_serializer=key_serializer,
            value_serializer=value_serializer,
            producer_config=producer_config,
            delivery_callback=delivery_callback,
        )
        return self.write_datasink(sink)

    @ConsumptionAPI(pattern="Time complexity:")
    def write_datasink(
        self,