Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
import trino
from tests.integration.conftest import trino_version
from trino import constants
from trino.client import SegmentIterator
from trino.dbapi import Cursor
from trino.dbapi import DescribeOutput
from trino.dbapi import TimeBoundLRUCache
from trino.exceptions import NotSupportedError
from trino.exceptions import TrinoQueryError
from trino.exceptions import TrinoUserError
from trino.mapper import RowMapperFactory
from trino.transaction import IsolationLevel


Expand Down Expand Up @@ -1876,12 +1878,16 @@ def test_segments_cursor(trino_connection):
start => 1,
stop => 5,
step => 1)) n""")
rows = cur.fetchall()
assert len(rows) > 0
for spooled_data, spooled_segment in rows:
assert spooled_data.encoding == trino_connection._client_session.encoding
assert isinstance(spooled_segment.uri, str), f"Expected string for uri, got {spooled_segment.uri}"
assert isinstance(spooled_segment.ack_uri, str), f"Expected string for ack_uri, got {spooled_segment.ack_uri}"
segments = cur.fetchall()
assert len(segments) > 0
row_mapper = RowMapperFactory().create(columns=cur._query.columns, legacy_primitive_types=False)
total = 0
for segment in segments:
assert segment.encoding == trino_connection._client_session.encoding
assert isinstance(segment.segment.uri, str), f"Expected string for uri, got {segment.segment.uri}"
assert isinstance(segment.segment.ack_uri, str), f"Expected string for ack_uri, got {segment.segment.ack_uri}"
total += len(list(SegmentIterator(segment, row_mapper)))
assert total == 300875, f"Expected total rows 300875, got {total}"


def get_cursor(legacy_prepared_statements, run_trino):
Expand Down
73 changes: 45 additions & 28 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
"TrinoQuery",
"TrinoRequest",
"PROXIES",
"SpooledData",
"DecodableSegment",
"SpooledSegment",
"InlineSegment",
"Segment"
Expand Down Expand Up @@ -573,6 +573,15 @@ def http_headers(self) -> CaseInsensitiveDict[str]:

return headers

def unauthenticated(self):
    """
    Build a new ``TrinoRequest`` that reuses this request's connection settings.

    The host, port, ``max_attempts``, request timeout and retry handler are
    forwarded unchanged. The client session is replaced by a fresh
    ``ClientSession`` that carries only the current session's ``user``.
    No other constructor arguments are passed, so they fall back to whatever
    defaults ``TrinoRequest`` defines.

    Returns:
        The newly constructed ``TrinoRequest``.
    """
    # Keys are listed in the order the values must be read from ``self``.
    request_kwargs = {
        "host": self._host,
        "port": self._port,
        "max_attempts": self.max_attempts,
        "request_timeout": self._request_timeout,
        "handle_retry": self._handle_retry,
        # Only the user name is taken over from the existing session.
        "client_session": ClientSession(user=self._client_session.user),
    }
    return TrinoRequest(**request_kwargs)

# Getter exposing the value stored in ``self._max_attempts``; ``unauthenticated()``
# reads it to pass the same ``max_attempts`` to the request it creates.
@property
def max_attempts(self) -> int:
return self._max_attempts
Expand Down Expand Up @@ -920,17 +929,18 @@ def fetch(self) -> List[Union[List[Any]], Any]:
if isinstance(status.rows, dict):
# spooling protocol
rows = cast(_SpooledProtocolResponseTO, rows)
segments = self._to_segments(rows)
spooled = self._to_segments(rows)
if self._fetch_mode == "segments":
return segments
return list(SegmentIterator(segments, self._row_mapper))
return spooled
return list(SegmentIterator(spooled, self._row_mapper))
elif isinstance(status.rows, list):
return self._row_mapper.map(rows)
else:
raise ValueError(f"Unexpected type: {type(status.rows)}")

def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData:
def _to_segments(self, rows: _SpooledProtocolResponseTO) -> List[DecodableSegment]:
encoding = rows["encoding"]
metadata = rows["metadata"] if "metadata" in rows else None
segments = []
for segment in rows["segments"]:
segment_type = segment["type"]
Expand All @@ -939,11 +949,11 @@ def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData:
segments.append(InlineSegment(inline_segment))
elif segment_type == SegmentType.SPOOLED:
spooled_segment = cast(_SpooledSegmentTO, segment)
segments.append(SpooledSegment(spooled_segment, self._request))
segments.append(SpooledSegment(spooled_segment, self._request.unauthenticated()))
else:
raise ValueError(f"Unsupported segment type: {segment_type}")

return SpooledData(encoding, segments)
return list(map(lambda segment: DecodableSegment(encoding, metadata, segment), segments))

def cancel(self) -> None:
"""Cancel the current query"""
Expand Down Expand Up @@ -1024,6 +1034,7 @@ def _parse_retry_after_header(retry_after):
# Trino Spooled protocol transfer objects
class _SpooledProtocolResponseTO(TypedDict):
encoding: Literal["json", "json+std", "json+lz4"]
metadata: _SegmentMetadataTO
segments: List[_SegmentTO]


Expand Down Expand Up @@ -1162,44 +1173,44 @@ def __repr__(self):
)


class DecodableSegment:
    """
    A single result segment bundled with the information needed to decode it.

    ``_to_segments`` creates one instance per entry of a spooled response's
    ``segments`` list; all instances built from the same response share that
    response's ``encoding`` and ``metadata``.

    Attributes:
        encoding (str): The encoding format of the segment data.
        metadata (Optional[_SegmentMetadataTO]): Metadata shared by all segments
            of the response, or ``None`` when the response had no ``metadata`` key.
        segment (Segment): The wrapped inline or spooled segment.
    """
    def __init__(self, encoding: str, metadata: Optional[_SegmentMetadataTO], segment: Segment) -> None:
        self._encoding = encoding
        # ``None`` is a legitimate value: the caller passes it when the response has no metadata.
        self._metadata = metadata
        self._segment = segment

    @property
    def encoding(self) -> str:
        """Encoding name that ``SegmentIterator`` uses to pick a decoder."""
        return self._encoding

    @property
    def segment(self) -> Segment:
        """The underlying segment object."""
        return self._segment

    @property
    def metadata(self) -> Optional[_SegmentMetadataTO]:
        """Response-level metadata, or ``None`` if none was supplied."""
        return self._metadata

    def __repr__(self) -> str:
        return f"DecodableSegment(encoding={self._encoding}, metadata={self._metadata}, segment={self._segment})"


class SegmentIterator:
def __init__(self, spooled_data: SpooledData, mapper: RowMapper) -> None:
self._segments = iter(spooled_data._segments)
self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(mapper).create(spooled_data.encoding))
def __init__(self, segments: Union[DecodableSegment, List[DecodableSegment]], mapper: RowMapper) -> None:
self._segments = iter(segments if isinstance(segments, List) else [segments])
self._mapper = mapper
self._decoder = None
self._rows: Iterator[List[List[Any]]] = iter([])
self._finished = False
self._current_segment: Optional[Segment] = None
self._current_segment: Optional[DecodableSegment] = None

def __iter__(self) -> Iterator[List[Any]]:
return self
Expand All @@ -1210,16 +1221,22 @@ def __next__(self) -> List[Any]:
try:
return next(self._rows)
except StopIteration:
if self._current_segment and isinstance(self._current_segment, SpooledSegment):
self._current_segment.acknowledge()
if self._finished:
raise StopIteration
self._load_next_segment()

def _load_next_segment(self):
try:
self._current_segment = segment = next(self._segments)
self._rows = iter(self._decoder.decode(segment))
if self._current_segment:
segment = self._current_segment.segment
if isinstance(segment, SpooledSegment):
segment.acknowledge()

self._current_segment = next(self._segments)
if self._decoder is None:
self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(self._mapper)
.create(self._current_segment.encoding))
self._rows = iter(self._decoder.decode(self._current_segment.segment))
except StopIteration:
self._finished = True

Expand Down