diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index d94f97f0..d7e8f7ea 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -1876,12 +1876,18 @@ def test_segments_cursor(trino_connection): start => 1, stop => 5, step => 1)) n""") - rows = cur.fetchall() - assert len(rows) > 0 - for spooled_data, spooled_segment in rows: + segments = cur.fetchall() + assert len(segments) > 0 + row_mapper = trino.mapper.RowMapperFactory().create(columns=cur._query.columns, legacy_primitive_types=False) + total = 0 + for spooled_data in segments: + assert len(spooled_data.segments) == 1, "Expected SpooledData to contain a single segment" + segment = spooled_data.segments[0] assert spooled_data.encoding == trino_connection._client_session.encoding - assert isinstance(spooled_segment.uri, str), f"Expected string for uri, got {spooled_segment.uri}" - assert isinstance(spooled_segment.ack_uri, str), f"Expected string for ack_uri, got {spooled_segment.ack_uri}" + assert isinstance(segment.uri, str), f"Expected string for uri, got {segment.uri}" + assert isinstance(segment.ack_uri, str), f"Expected string for ack_uri, got {segment.ack_uri}" + total += len(list(trino.client.SegmentIterator(spooled_data, row_mapper))) + assert total == 300875, f"Expected total rows 300875, got {total}" def get_cursor(legacy_prepared_statements, run_trino): diff --git a/trino/client.py b/trino/client.py index a122702f..92de2dd0 100644 --- a/trino/client.py +++ b/trino/client.py @@ -1187,7 +1187,8 @@ def __iter__(self) -> Iterator[Tuple["SpooledData", "Segment"]]: return self def __next__(self) -> Tuple["SpooledData", "Segment"]: - return self, next(self._segments_iterator) + segment = next(self._segments_iterator) + return SpooledData(self._encoding, [segment]), segment def __repr__(self): return (f"SpooledData(encoding={self._encoding}, segments={list(self._segments)})") diff --git a/trino/dbapi.py b/trino/dbapi.py index dee7cdb7..00adeb8d 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -742,7 +742,7 @@ def execute(self, operation, params=None): self._query = trino.client.TrinoQuery(self._request, query=operation, legacy_primitive_types=self._legacy_primitive_types, fetch_mode="segments") - self._iterator = iter(self._query.execute()) + self._iterator = map(lambda tuple: tuple[0], iter(self._query.execute())) return self