diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py
index d94f97f0..276d6971 100644
--- a/tests/integration/test_dbapi_integration.py
+++ b/tests/integration/test_dbapi_integration.py
@@ -29,12 +29,14 @@
 import trino
 from tests.integration.conftest import trino_version
 from trino import constants
+from trino.client import SegmentIterator
 from trino.dbapi import Cursor
 from trino.dbapi import DescribeOutput
 from trino.dbapi import TimeBoundLRUCache
 from trino.exceptions import NotSupportedError
 from trino.exceptions import TrinoQueryError
 from trino.exceptions import TrinoUserError
+from trino.mapper import RowMapperFactory
 from trino.transaction import IsolationLevel
 
 
@@ -1876,12 +1878,16 @@ def test_segments_cursor(trino_connection):
             start => 1,
             stop => 5,
             step => 1)) n""")
-    rows = cur.fetchall()
-    assert len(rows) > 0
-    for spooled_data, spooled_segment in rows:
-        assert spooled_data.encoding == trino_connection._client_session.encoding
-        assert isinstance(spooled_segment.uri, str), f"Expected string for uri, got {spooled_segment.uri}"
-        assert isinstance(spooled_segment.ack_uri, str), f"Expected string for ack_uri, got {spooled_segment.ack_uri}"
+    segments = cur.fetchall()
+    assert len(segments) > 0
+    row_mapper = RowMapperFactory().create(columns=cur._query.columns, legacy_primitive_types=False)
+    total = 0
+    for segment in segments:
+        assert segment.encoding == trino_connection._client_session.encoding
+        assert isinstance(segment.segment.uri, str), f"Expected string for uri, got {segment.segment.uri}"
+        assert isinstance(segment.segment.ack_uri, str), f"Expected string for ack_uri, got {segment.segment.ack_uri}"
+        total += len(list(SegmentIterator(segment, row_mapper)))
+    assert total == 300875, f"Expected total rows 300875, got {total}"
 
 
 def get_cursor(legacy_prepared_statements, run_trino):
diff --git a/trino/client.py b/trino/client.py
index a122702f..637fe82b 100644
--- a/trino/client.py
+++ b/trino/client.py
@@ -89,7 +89,7 @@
     "TrinoQuery",
     "TrinoRequest",
     "PROXIES",
-    "SpooledData",
+    "DecodableSegment",
     "SpooledSegment",
     "InlineSegment",
     "Segment"
@@ -573,6 +573,16 @@ def http_headers(self) -> CaseInsensitiveDict[str]:
 
         return headers
 
+    def unauthenticated(self) -> "TrinoRequest":
+        """Build a new request to the same host/port, forwarding only retry/timeout settings and the session user."""
+        return TrinoRequest(
+            host=self._host,
+            port=self._port,
+            max_attempts=self.max_attempts,
+            request_timeout=self._request_timeout,
+            handle_retry=self._handle_retry,
+            client_session=ClientSession(user=self._client_session.user))
+
     @property
     def max_attempts(self) -> int:
         return self._max_attempts
@@ -920,17 +930,18 @@ def fetch(self) -> List[Union[List[Any]], Any]:
         if isinstance(status.rows, dict):
             # spooling protocol
             rows = cast(_SpooledProtocolResponseTO, rows)
-            segments = self._to_segments(rows)
+            spooled = self._to_segments(rows)
             if self._fetch_mode == "segments":
-                return segments
-            return list(SegmentIterator(segments, self._row_mapper))
+                return spooled
+            return list(SegmentIterator(spooled, self._row_mapper))
         elif isinstance(status.rows, list):
             return self._row_mapper.map(rows)
         else:
             raise ValueError(f"Unexpected type: {type(status.rows)}")
 
-    def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData:
+    def _to_segments(self, rows: _SpooledProtocolResponseTO) -> List[DecodableSegment]:
         encoding = rows["encoding"]
+        metadata = rows.get("metadata")
         segments = []
         for segment in rows["segments"]:
             segment_type = segment["type"]
@@ -939,11 +950,11 @@ def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData:
                 segments.append(InlineSegment(inline_segment))
             elif segment_type == SegmentType.SPOOLED:
                 spooled_segment = cast(_SpooledSegmentTO, segment)
-                segments.append(SpooledSegment(spooled_segment, self._request))
+                segments.append(SpooledSegment(spooled_segment, self._request.unauthenticated()))
             else:
                 raise ValueError(f"Unsupported segment type: {segment_type}")
 
-        return SpooledData(encoding, segments)
+        return [DecodableSegment(encoding, metadata, segment) for segment in segments]
 
     def cancel(self) -> None:
         """Cancel the current query"""
@@ -1024,6 +1035,7 @@ def _parse_retry_after_header(retry_after):
 # Trino Spooled protocol transfer objects
 class _SpooledProtocolResponseTO(TypedDict):
     encoding: Literal["json", "json+std", "json+lz4"]
+    metadata: _SegmentMetadataTO
     segments: List[_SegmentTO]
 
 
@@ -1162,44 +1174,44 @@ def __repr__(self):
         )
 
 
-class SpooledData:
+class DecodableSegment:
     """
-    Represents a collection of spooled segments of data, with an encoding format.
+    Represents a single segment of data together with the encoding and metadata needed to decode it.
 
     Attributes:
         encoding (str): The encoding format of the spooled data.
-        segments (List[Segment]): The list of segments in the spooled data.
+        metadata (Optional[_SegmentMetadataTO]): Metadata from the response, or None when it has none
+        segment (Segment): The wrapped inline or spooled segment
     """
-    def __init__(self, encoding: str, segments: List[Segment]) -> None:
+    def __init__(self, encoding: str, metadata: Optional[_SegmentMetadataTO], segment: Segment) -> None:
         self._encoding = encoding
-        self._segments = segments
-        self._segments_iterator = iter(segments)
+        self._metadata = metadata
+        self._segment = segment
 
     @property
     def encoding(self):
         return self._encoding
 
     @property
-    def segments(self):
-        return self._segments
+    def segment(self):
+        return self._segment
 
-    def __iter__(self) -> Iterator[Tuple["SpooledData", "Segment"]]:
-        return self
-
-    def __next__(self) -> Tuple["SpooledData", "Segment"]:
-        return self, next(self._segments_iterator)
+    @property
+    def metadata(self):
+        return self._metadata
 
     def __repr__(self):
-        return (f"SpooledData(encoding={self._encoding}, segments={list(self._segments)})")
+        return (f"DecodableSegment(encoding={self._encoding}, metadata={self._metadata}, segment={self._segment})")
 
 
 class SegmentIterator:
-    def __init__(self, spooled_data: SpooledData, mapper: RowMapper) -> None:
-        self._segments = iter(spooled_data._segments)
-        self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(mapper).create(spooled_data.encoding))
+    def __init__(self, segments: Union[DecodableSegment, List[DecodableSegment]], mapper: RowMapper) -> None:
+        self._segments = iter(segments if isinstance(segments, list) else [segments])
+        self._mapper = mapper
+        self._decoder: Optional[SegmentDecoder] = None
         self._rows: Iterator[List[List[Any]]] = iter([])
         self._finished = False
-        self._current_segment: Optional[Segment] = None
+        self._current_segment: Optional[DecodableSegment] = None
 
     def __iter__(self) -> Iterator[List[Any]]:
         return self
@@ -1210,16 +1222,22 @@ def __next__(self) -> List[Any]:
             try:
                 return next(self._rows)
             except StopIteration:
-                if self._current_segment and isinstance(self._current_segment, SpooledSegment):
-                    self._current_segment.acknowledge()
                 if self._finished:
                     raise StopIteration
                 self._load_next_segment()
 
     def _load_next_segment(self):
         try:
-            self._current_segment = segment = next(self._segments)
-            self._rows = iter(self._decoder.decode(segment))
+            if self._current_segment:
+                segment = self._current_segment.segment
+                if isinstance(segment, SpooledSegment):
+                    segment.acknowledge()
+
+            self._current_segment = next(self._segments)
+            if self._decoder is None:
+                self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(self._mapper)
+                                               .create(self._current_segment.encoding))
+            self._rows = iter(self._decoder.decode(self._current_segment.segment))
         except StopIteration:
             self._finished = True
 