Skip to content

Commit 099bb33

Browse files
committed
Rename SpooledData to DecodableSegment and remove Tuple iterator
This makes the API much easier to consume in the `segments` cursor style: ``` cur = conn.cursor('segment') cur.execute(sql) segments = cur.fetchall() total_row_count = 0 row_mapper = RowMapperFactory().create(columns=cur._query.columns, legacy_primitive_types=False) for segment in segments: rows = list(SegmentIterator(segment, row_mapper)) print ("rows length is " + str(len(rows)) + " " + segment.encoding) total_row_count += len(rows) print(total_row_count) ``` This will work as well: ``` cur = conn.cursor('segment') cur.execute(sql) segments = cur.fetchall() total_row_count = 0 row_mapper = RowMapperFactory().create(columns=cur._query.columns, legacy_primitive_types=False) rows = list(SegmentIterator(segments, row_mapper)) print ("rows length is " + str(len(rows))) total_row_count += len(rows) print(total_row_count) ```
1 parent 00dfb5e commit 099bb33

File tree

2 files changed

+44
-34
lines changed

2 files changed

+44
-34
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@
2929
import trino
3030
from tests.integration.conftest import trino_version
3131
from trino import constants
32+
from trino.client import SegmentIterator
3233
from trino.dbapi import Cursor
3334
from trino.dbapi import DescribeOutput
3435
from trino.dbapi import TimeBoundLRUCache
3536
from trino.exceptions import NotSupportedError
3637
from trino.exceptions import TrinoQueryError
3738
from trino.exceptions import TrinoUserError
39+
from trino.mapper import RowMapperFactory
3840
from trino.transaction import IsolationLevel
3941

4042

@@ -1876,12 +1878,16 @@ def test_segments_cursor(trino_connection):
18761878
start => 1,
18771879
stop => 5,
18781880
step => 1)) n""")
1879-
rows = cur.fetchall()
1880-
assert len(rows) > 0
1881-
for spooled_data, spooled_segment in rows:
1882-
assert spooled_data.encoding == trino_connection._client_session.encoding
1883-
assert isinstance(spooled_segment.uri, str), f"Expected string for uri, got {spooled_segment.uri}"
1884-
assert isinstance(spooled_segment.ack_uri, str), f"Expected string for ack_uri, got {spooled_segment.ack_uri}"
1881+
segments = cur.fetchall()
1882+
assert len(segments) > 0
1883+
row_mapper = RowMapperFactory().create(columns=cur._query.columns, legacy_primitive_types=False)
1884+
total = 0
1885+
for segment in segments:
1886+
assert segment.encoding == trino_connection._client_session.encoding
1887+
assert isinstance(segment.segment.uri, str), f"Expected string for uri, got {segment.segment.uri}"
1888+
assert isinstance(segment.segment.ack_uri, str), f"Expected string for ack_uri, got {segment.segment.ack_uri}"
1889+
total += len(list(SegmentIterator(segment, row_mapper)))
1890+
assert total == 300875, f"Expected total rows 300875, got {total}"
18851891

18861892

18871893
def get_cursor(legacy_prepared_statements, run_trino):

trino/client.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
"TrinoQuery",
9090
"TrinoRequest",
9191
"PROXIES",
92-
"SpooledData",
92+
"DecodableSegment",
9393
"SpooledSegment",
9494
"InlineSegment",
9595
"Segment"
@@ -920,16 +920,16 @@ def fetch(self) -> List[Union[List[Any]], Any]:
920920
if isinstance(status.rows, dict):
921921
# spooling protocol
922922
rows = cast(_SpooledProtocolResponseTO, rows)
923-
segments = self._to_segments(rows)
923+
spooled = self._to_segments(rows)
924924
if self._fetch_mode == "segments":
925-
return segments
926-
return list(SegmentIterator(segments, self._row_mapper))
925+
return spooled
926+
return list(SegmentIterator(spooled, self._row_mapper))
927927
elif isinstance(status.rows, list):
928928
return self._row_mapper.map(rows)
929929
else:
930930
raise ValueError(f"Unexpected type: {type(status.rows)}")
931931

932-
def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData:
932+
def _to_segments(self, rows: _SpooledProtocolResponseTO) -> List[DecodableSegment]:
933933
encoding = rows["encoding"]
934934
metadata = rows["metadata"] if "metadata" in rows else None
935935
segments = []
@@ -944,7 +944,7 @@ def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData:
944944
else:
945945
raise ValueError(f"Unsupported segment type: {segment_type}")
946946

947-
return SpooledData(encoding, metadata, segments)
947+
return list(map(lambda segment: DecodableSegment(encoding, metadata, segment), segments))
948948

949949
def cancel(self) -> None:
950950
"""Cancel the current query"""
@@ -1164,46 +1164,44 @@ def __repr__(self):
11641164
)
11651165

11661166

1167-
class SpooledData:
1167+
class DecodableSegment:
11681168
"""
11691169
Represents a collection of spooled segments of data, with an encoding format.
11701170
11711171
Attributes:
11721172
encoding (str): The encoding format of the spooled data.
1173-
metadata (_SegmentMetadataTO): Metadata for all segments
1174-
segments (List[Segment]): The list of segments in the spooled data.
1173+
metadata (_SegmentMetadataTO): Metadata for all segments in the query
1174+
segment (Segment): The spooled segment data
11751175
"""
1176-
def __init__(self, encoding: str, metadata: _SegmentMetadataTO, segments: List[Segment]) -> None:
1176+
def __init__(self, encoding: str, metadata: _SegmentMetadataTO, segment: Segment) -> None:
11771177
self._encoding = encoding
11781178
self._metadata = metadata
1179-
self._segments = segments
1180-
self._segments_iterator = iter(segments)
1179+
self._segment = segment
11811180

11821181
@property
11831182
def encoding(self):
11841183
return self._encoding
11851184

11861185
@property
1187-
def segments(self):
1188-
return self._segments
1189-
1190-
def __iter__(self) -> Iterator[Tuple["SpooledData", "Segment"]]:
1191-
return self
1186+
def segment(self):
1187+
return self._segment
11921188

1193-
def __next__(self) -> Tuple["SpooledData", "Segment"]:
1194-
return self, next(self._segments_iterator)
1189+
@property
1190+
def metadata(self):
1191+
return self._metadata
11951192

11961193
def __repr__(self):
1197-
return (f"SpooledData(encoding={self._encoding}, metadata={self._metadata}, segments={list(self._segments)})")
1194+
return (f"DecodableSegment(encoding={self._encoding}, metadata={self._metadata}, segment={self._segment})")
11981195

11991196

12001197
class SegmentIterator:
1201-
def __init__(self, spooled_data: SpooledData, mapper: RowMapper) -> None:
1202-
self._segments = iter(spooled_data._segments)
1203-
self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(mapper).create(spooled_data.encoding))
1198+
def __init__(self, segments: Union[DecodableSegment, List[DecodableSegment]], mapper: RowMapper) -> None:
1199+
self._segments = iter(segments if isinstance(segments, List) else [segments])
1200+
self._mapper = mapper
1201+
self._decoder = None
12041202
self._rows: Iterator[List[List[Any]]] = iter([])
12051203
self._finished = False
1206-
self._current_segment: Optional[Segment] = None
1204+
self._current_segment: Optional[DecodableSegment] = None
12071205

12081206
def __iter__(self) -> Iterator[List[Any]]:
12091207
return self
@@ -1214,16 +1212,22 @@ def __next__(self) -> List[Any]:
12141212
try:
12151213
return next(self._rows)
12161214
except StopIteration:
1217-
if self._current_segment and isinstance(self._current_segment, SpooledSegment):
1218-
self._current_segment.acknowledge()
12191215
if self._finished:
12201216
raise StopIteration
12211217
self._load_next_segment()
12221218

12231219
def _load_next_segment(self):
12241220
try:
1225-
self._current_segment = segment = next(self._segments)
1226-
self._rows = iter(self._decoder.decode(segment))
1221+
if self._current_segment:
1222+
segment = self._current_segment.segment
1223+
if isinstance(segment, SpooledSegment):
1224+
segment.acknowledge()
1225+
1226+
self._current_segment = next(self._segments)
1227+
if self._decoder is None:
1228+
self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(self._mapper)
1229+
.create(self._current_segment.encoding))
1230+
self._rows = iter(self._decoder.decode(self._current_segment.segment))
12271231
except StopIteration:
12281232
self._finished = True
12291233

0 commit comments

Comments
 (0)