Skip to content

Commit c517804

Browse files
committed
Add write_zipped_dataframes for S3ResultsWriter
1 parent 80dd301 commit c517804

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

sdgym/result_writer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,16 @@ def write_yaml(self, data, file_path, append=False):
154154
run_data.update(data)
155155
new_content = yaml.dump(run_data)
156156
self.s3_client.put_object(Body=new_content.encode(), Bucket=bucket, Key=key)
157+
158+
def write_zipped_dataframes(self, data, file_path, index=False):
    """Write a dictionary of DataFrames to a ZIP file in S3.

    Each table is serialized to CSV and stored inside the archive as
    ``<table_name>.csv``. The archive is assembled in memory and then
    uploaded to the bucket/key obtained from ``file_path``.

    Args:
        data (dict):
            Mapping of table name to the DataFrame to store. Each value must
            provide a ``to_csv(buffer, index=...)`` method.
        file_path (str):
            S3 path of the ZIP file; split into bucket and key by
            ``parse_s3_path``.
        index (bool):
            Passed to ``to_csv`` to control whether the index is written.
            Defaults to ``False``.
    """
    bucket, key = parse_s3_path(file_path)
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
        for table_name, table in data.items():
            csv_buf = io.StringIO()
            table.to_csv(csv_buf, index=index)
            zf.writestr(f'{table_name}.csv', csv_buf.getvalue())

    # Leaving the ``with`` block finalizes the archive; rewind so the upload
    # reads it from the first byte.
    zip_buffer.seek(0)
    # ``upload_fileobj`` takes (Fileobj, Bucket, Key). ``Body`` is a
    # ``put_object`` parameter and is rejected here with a TypeError.
    self.s3_client.upload_fileobj(zip_buffer, bucket, key)

tests/unit/test_result_writer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import zipfile
12
from unittest.mock import Mock, patch
23

34
import cloudpickle
@@ -224,3 +225,32 @@ def test_write_yaml_append(self, mockparse_s3_path):
224225
Bucket='bucket_name',
225226
Key='key_prefix/test_data.yaml',
226227
)
228+
229+
@patch('sdgym.result_writer.parse_s3_path')
def test_write_zipped_dataframes(self, mockparse_s3_path):
    """Test the `write_zipped_dataframes` method.

    Real DataFrames and a real in-memory buffer are used so that the uploaded
    object can be reopened as a ZIP archive and its CSV members inspected.
    """
    # Setup
    mock_s3_client = Mock()
    mockparse_s3_path.return_value = ('bucket_name', 'key_prefix/test_data.zip')
    result_writer = S3ResultsWriter(mock_s3_client)
    df1 = pd.DataFrame({'col1': [1, 2]})
    df2 = pd.DataFrame({'colA': ['x', 'y']})
    data = {'table1': df1, 'table2': df2}

    # Run
    result_writer.write_zipped_dataframes(data, 'test_data.zip')

    # Assert
    mockparse_s3_path.assert_called_once_with('test_data.zip')
    mock_s3_client.upload_fileobj.assert_called_once()

    # Accept positional or keyword arguments, but only under the parameter
    # names that boto3's ``upload_fileobj`` actually defines.
    args, kwargs = mock_s3_client.upload_fileobj.call_args
    call_params = dict(zip(('Fileobj', 'Bucket', 'Key'), args))
    call_params.update(kwargs)
    assert call_params['Bucket'] == 'bucket_name'
    assert call_params['Key'] == 'key_prefix/test_data.zip'

    uploaded_buffer = call_params['Fileobj']
    uploaded_buffer.seek(0)
    with zipfile.ZipFile(uploaded_buffer, 'r') as zf:
        assert set(zf.namelist()) == {'table1.csv', 'table2.csv'}
        assert zf.read('table1.csv').decode() == df1.to_csv(index=False)
        assert zf.read('table2.csv').decode() == df2.to_csv(index=False)

0 commit comments

Comments
 (0)