@@ -1,4 +1,4 @@
-from typing import Dict, List, Any, NamedTuple, Optional, final, Literal
+from typing import Dict, List, Any, NamedTuple, Optional, final, Union
 import pyarrow.csv as csv
 from dataclasses import dataclass
 import dataclasses
@@ -8,37 +8,59 @@
 import json
 from pathlib import Path
 from cloud2sql.arrow.model import ArrowModel
+from cloud2sql.arrow.config import (
+    ArrowOutputConfig,
+    FileDestination,
+    CloudBucketDestination,
+    S3Bucket,
+    GCSBucket,
+    ArrowDestination,
+)
 from cloud2sql.schema_utils import insert_node
 from resotoclient.models import JsObject

+import boto3
+from google.cloud import storage
+import hashlib
+

 class WriteResult(NamedTuple):
     table_name: str


-class FileWriter(ABC):
-    pass
-
-
 @final
 @dataclass(frozen=True)
-class Parquet(FileWriter):
+class Parquet:
     parquet_writer: pq.ParquetWriter


 @final
 @dataclass(frozen=True)
-class CSV(FileWriter):
+class CSV:
     csv_writer: csv.CSVWriter

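+# The concrete Arrow writer for one output file: Parquet or CSV.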
+FileWriterFormat = Union[Parquet, CSV]
+
+
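+# Couples the output path with its format-specific writer so a finished file can be located again, e.g. for upload.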
 @final
-@dataclass
+@dataclass(frozen=True)
+class FileWriter:
+    path: Path
+    format: FileWriterFormat
+
+
+@final
+@dataclass(frozen=True)
 class ArrowBatch:
     table_name: str
     rows: List[Dict[str, Any]]
     schema: pa.Schema
     writer: FileWriter
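+    # Local directory or cloud bucket where the finished file should end up.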
+    destination: ArrowDestination


 class ConversionTarget(ABC):
@@ -157,60 +179,107 @@ def write_batch_to_file(batch: ArrowBatch) -> ArrowBatch:
         normalize(path, row)

     pa_table = pa.Table.from_pylist(batch.rows, batch.schema)
-    if isinstance(batch.writer, Parquet):
-        batch.writer.parquet_writer.write_table(pa_table)
-    elif isinstance(batch.writer, CSV):
-        batch.writer.csv_writer.write_table(pa_table)
+    if isinstance(batch.writer.format, Parquet):
+        batch.writer.format.parquet_writer.write_table(pa_table)
+    elif isinstance(batch.writer.format, CSV):
+        batch.writer.format.csv_writer.write_table(pa_table)
     else:
         raise ValueError(f"Unknown format {batch.writer}")
-    return ArrowBatch(table_name=batch.table_name, rows=[], schema=batch.schema, writer=batch.writer)
+    return ArrowBatch(
+        table_name=batch.table_name, rows=[], schema=batch.schema, writer=batch.writer, destination=batch.destination
+    )


 def close_writer(batch: ArrowBatch) -> None:
-    if isinstance(batch.writer, Parquet):
-        batch.writer.parquet_writer.close()
-    elif isinstance(batch.writer, CSV):
-        batch.writer.csv_writer.close()
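+    # Upload the closed file to the configured S3 bucket.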
+    def uploadToS3(path: Path, bucket_name: str, region: str) -> None:
+        s3_client = boto3.client("s3", region_name=region)
+        s3_client.upload_file(str(path), bucket_name, path.name)
+
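+    # Upload the closed file to the configured GCS bucket.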
+    def uploadToGCS(path: Path, bucket_name: str) -> None:
+        storage_client = storage.Client()
+        bucket = storage_client.bucket(bucket_name)
+        blob = bucket.blob(path.name)
+        blob.upload_from_filename(str(path))
+
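+    # Upload only if the destination is a cloud bucket; plain file destinations stay local.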
+    def maybeUpload() -> None:
+        if isinstance(batch.destination, CloudBucketDestination):
+            destination = batch.destination
+            if isinstance(destination.cloud_bucket, S3Bucket):
+                uploadToS3(batch.writer.path, destination.bucket_name, destination.cloud_bucket.region)
+            elif isinstance(destination.cloud_bucket, GCSBucket):
+                uploadToGCS(batch.writer.path, destination.bucket_name)
+            else:
+                raise ValueError(f"Unknown cloud bucket {destination.cloud_bucket}")
+
+    if isinstance(batch.writer.format, Parquet):
+        batch.writer.format.parquet_writer.close()
+        maybeUpload()
+    elif isinstance(batch.writer.format, CSV):
+        batch.writer.format.csv_writer.close()
+        maybeUpload()
     else:
         raise ValueError(f"Unknown format {batch.writer}")


-def new_writer(format: Literal["parquet", "csv"], table_name: str, schema: pa.Schema, result_dir: Path) -> FileWriter:
+def new_writer(table_name: str, schema: pa.Schema, output_config: ArrowOutputConfig) -> FileWriter:
     def ensure_path(path: Path) -> Path:
         path.mkdir(parents=True, exist_ok=True)
         return path

-    if format == "parquet":
-        return Parquet(pq.ParquetWriter(Path(ensure_path(result_dir), f"{table_name}.parquet"), schema=schema))
-    elif format == "csv":
-        return CSV(csv.CSVWriter(Path(ensure_path(result_dir), f"{table_name}.csv"), schema=schema))
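+    # Hash the bucket name into a stable, filesystem-safe staging directory name.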
+    def sha(input: str) -> str:
+        h = hashlib.new("sha256")
+        h.update(input.encode("utf-8"))
+        return h.hexdigest()
+
+    if isinstance(output_config.destination, FileDestination):
+        result_dir = ensure_path(output_config.destination.path)
     else:
-        raise ValueError(f"Unknown format {format}")
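+        # Cloud-bound files are staged in a local temp directory and uploaded on close.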
+        hashed_url = sha(output_config.destination.bucket_name)
+        result_dir = ensure_path(Path(f"/tmp/cloud2sql-uploads/{hashed_url}"))
+
+    file_writer_format: Union[Parquet, CSV]
+    file_path: Path
+    if output_config.format == "parquet":
+        file_path = Path(ensure_path(result_dir), f"{table_name}.parquet")
+        file_writer_format = Parquet(
+            pq.ParquetWriter(file_path, schema=schema),
+        )
+    elif output_config.format == "csv":
+        file_path = Path(ensure_path(result_dir), f"{table_name}.csv")
+        file_writer_format = CSV(
+            csv.CSVWriter(file_path, schema=schema),
+        )
+    else:
+        raise ValueError(f"Unknown format {output_config.format}")
+
+    return FileWriter(file_path, file_writer_format)


 class ArrowWriter:
-    def __init__(
-        self, model: ArrowModel, result_directory: Path, rows_per_batch: int, output_format: Literal["parquet", "csv"]
-    ):
+    def __init__(self, model: ArrowModel, output_config: ArrowOutputConfig):
         self.model = model
         self.kind_by_id: Dict[str, str] = {}
         self.batches: Dict[str, ArrowBatch] = {}
-        self.rows_per_batch: int = rows_per_batch
-        self.result_directory: Path = result_directory
-        self.output_format: Literal["parquet", "csv"] = output_format
+        self.output_config: ArrowOutputConfig = output_config

     def insert_value(self, table_name: str, values: Any) -> Optional[WriteResult]:
         if self.model.schemas.get(table_name):
             schema = self.model.schemas[table_name]
-            batch = self.batches.get(
-                table_name,
-                ArrowBatch(
+            if table_name in self.batches:
+                batch = self.batches[table_name]
+            else:
+                batch = ArrowBatch(
                     table_name,
                     [],
                     schema,
-                    new_writer(self.output_format, table_name, schema, self.result_directory),
-                ),
-            )
+                    new_writer(table_name, schema, self.output_config),
+                    self.output_config.destination,
+                )

             batch.rows.append(values)
             self.batches[table_name] = batch
@@ -224,7 +293,8 @@ def insert_node(self, node: JsObject) -> None:
             self.insert_value,
             with_tmp_prefix=False,
         )
-        should_write_batch = result and len(self.batches[result.table_name].rows) > self.rows_per_batch
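+        # Flush a table's buffer to disk once it exceeds the configured batch size.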
+        should_write_batch = result and len(self.batches[result.table_name].rows) > self.output_config.batch_size
         if result and should_write_batch:
             batch = self.batches[result.table_name]
             self.batches[result.table_name] = write_batch_to_file(batch)
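
For reference, a minimal sketch of how the changed writer might be driven end to end. The keyword arguments for ArrowOutputConfig, CloudBucketDestination, and S3Bucket are assumptions inferred from the attribute accesses in this diff (format, batch_size, destination, bucket_name, cloud_bucket, region), not confirmed constructor signatures, and model and nodes are placeholders for an ArrowModel and a resoto node stream built elsewhere.

    # Hypothetical usage sketch; constructor kwargs are inferred from this diff, not confirmed.
    from cloud2sql.arrow.config import ArrowOutputConfig, CloudBucketDestination, S3Bucket

    # Stage parquet files locally, then upload each closed file to S3 (assumed kwargs).
    config = ArrowOutputConfig(
        destination=CloudBucketDestination(bucket_name="my-bucket", cloud_bucket=S3Bucket(region="us-east-1")),
        format="parquet",
        batch_size=100_000,
    )
    writer = ArrowWriter(model, config)  # model: an ArrowModel built from the resoto schema
    for node in nodes:  # JsObject payloads from resotoclient
        writer.insert_node(node)
    # Flush any remaining rows, then close each file; closing triggers the cloud upload.
    for name, batch in writer.batches.items():
        writer.batches[name] = write_batch_to_file(batch)
        close_writer(writer.batches[name])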