 from resotoclient.models import Kind, Model, JsObject
-from typing import Dict, List, Any, NamedTuple, Optional, Tuple
+from typing import Dict, List, Any, NamedTuple, Optional, Tuple, final, Literal
 import pyarrow as pa
+import pyarrow.csv as csv
 from cloud2sql.schema_utils import (
     base_kinds,
     get_table_name,
 import pyarrow.parquet as pq
 from pathlib import Path
 from dataclasses import dataclass
+from abc import ABC
 
 
-class ParquetModel:
+class ArrowModel:
     def __init__(self, model: Model):
         self.model = model
         self.table_kinds = [
@@ -23,7 +25,7 @@ def __init__(self, model: Model):
         ]
         self.schemas: Dict[str, pa.Schema] = {}
 
-    def _parquet_type(self, kind: str) -> pa.lib.DataType:
+    def _pyarrow_type(self, kind: str) -> pa.lib.DataType:
         if kind.startswith("dict") or "[]" in kind:
             return pa.string()  # dicts and lists are converted to json strings
         elif kind == "int32":
@@ -49,7 +51,7 @@ def table_schema(kind: Kind) -> None:
                 schema = pa.schema(
                     [
                         pa.field("_id", pa.string()),
-                        *[pa.field(p.name, self._parquet_type(p.kind)) for p in properties],
+                        *[pa.field(p.name, self._pyarrow_type(p.kind)) for p in properties],
                     ]
                 )
                 self.schemas[table_name] = schema
@@ -90,41 +92,85 @@ class WriteResult(NamedTuple):
     table_name: str
 
 
+class FileWriter(ABC):
+    pass
+
+
+@final
+@dataclass(frozen=True)
+class Parquet(FileWriter):
+    parquet_writer: pq.ParquetWriter
+
+
+@final
+@dataclass(frozen=True)
+class CSV(FileWriter):
+    csv_writer: csv.CSVWriter
+
+
+@final
 @dataclass
-class ParquetBatch:
+class ArrowBatch:
     rows: List[Dict[str, Any]]
     schema: pa.Schema
-    writer: pq.ParquetWriter
+    writer: FileWriter
+
+
+def write_batch_to_file(batch: ArrowBatch) -> ArrowBatch:
+    pa_table = pa.Table.from_pylist(batch.rows, batch.schema)
+    if isinstance(batch.writer, Parquet):
+        batch.writer.parquet_writer.write_table(pa_table)
+    elif isinstance(batch.writer, CSV):
+        batch.writer.csv_writer.write_table(pa_table)
+    else:
+        raise ValueError(f"Unknown format {batch.writer}")
+    return ArrowBatch(rows=[], schema=batch.schema, writer=batch.writer)
+
 
+def close_writer(batch: ArrowBatch) -> None:
+    if isinstance(batch.writer, Parquet):
+        batch.writer.parquet_writer.close()
+    elif isinstance(batch.writer, CSV):
+        batch.writer.csv_writer.close()
+    else:
+        raise ValueError(f"Unknown format {batch.writer}")
 
-class ParquetWriter:
+
+def new_writer(
+    format: Literal["parquet", "csv"], table_name: str, schema: pa.Schema, result_dir: Path
+) -> FileWriter:
+    def ensure_path(path: Path) -> Path:
+        path.mkdir(parents=True, exist_ok=True)
+        return path
+
+    if format == "parquet":
+        return Parquet(pq.ParquetWriter(Path(ensure_path(result_dir), f"{table_name}.parquet"), schema=schema))
+    elif format == "csv":
+        return CSV(csv.CSVWriter(Path(ensure_path(result_dir), f"{table_name}.csv"), schema=schema))
+    else:
+        raise ValueError(f"Unknown format {format}")
+
+
+class ArrowWriter:
     def __init__(
-        self,
-        model: ParquetModel,
-        result_directory: Path,
-        rows_per_batch: int,
+        self, model: ArrowModel, result_directory: Path, rows_per_batch: int, output_format: Literal["parquet", "csv"]
     ):
         self.model = model
         self.kind_by_id: Dict[str, str] = {}
-        self.batches: Dict[str, ParquetBatch] = {}
-        self.rows_per_batch = rows_per_batch
-        self.result_directory = result_directory
+        self.batches: Dict[str, ArrowBatch] = {}
+        self.rows_per_batch: int = rows_per_batch
+        self.result_directory: Path = result_directory
+        self.output_format: Literal["parquet", "csv"] = output_format
 
     def insert_value(self, table_name: str, values: Any) -> Optional[WriteResult]:
         if self.model.schemas.get(table_name):
-
-            def ensure_path(path: Path) -> Path:
-                path.mkdir(parents=True, exist_ok=True)
-                return path
-
             batch = self.batches.get(
                 table_name,
-                ParquetBatch(
+                ArrowBatch(
                     [],
                     self.model.schemas[table_name],
-                    pq.ParquetWriter(
-                        Path(ensure_path(self.result_directory), f"{table_name}.parquet"),
-                        self.model.schemas[table_name],
+                    new_writer(
+                        self.output_format, table_name, self.model.schemas[table_name], self.result_directory
                     ),
                 ),
             )
@@ -134,12 +180,6 @@ def ensure_path(path: Path) -> Path:
             return WriteResult(table_name)
         return None
 
-    def write_batch_bundle(self, batch: ParquetBatch) -> None:
-        rows = batch.rows
-        batch.rows = []
-        pa_table = pa.Table.from_pylist(rows, batch.schema)
-        batch.writer.write_table(pa_table)
-
     def insert_node(self, node: JsObject) -> None:
         result = insert_node(
             node,
@@ -151,9 +191,10 @@ def insert_node(self, node: JsObject) -> None:
         should_write_batch = result and len(self.batches[result.table_name].rows) > self.rows_per_batch
         if result and should_write_batch:
             batch = self.batches[result.table_name]
-            self.write_batch_bundle(batch)
+            self.batches[result.table_name] = write_batch_to_file(batch)
 
     def close(self) -> None:
-        for batch in self.batches.values():
-            self.write_batch_bundle(batch)
-            batch.writer.close()
+        for table_name, batch in self.batches.items():
+            batch = write_batch_to_file(batch)
+            self.batches[table_name] = batch
+            close_writer(batch)
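
For reference, a minimal usage sketch of the classes this diff introduces (plain Python, not part of the commit). Here `resoto_model` and `nodes` are hypothetical placeholders for a resotoclient Model and an iterable of node JsObjects, and the model's per-table schemas are assumed to be populated beforehand (the diff's table_schema helper fills self.schemas):

    from pathlib import Path

    arrow_model = ArrowModel(resoto_model)  # resoto_model: resotoclient Model (assumed)
    writer = ArrowWriter(
        model=arrow_model,
        result_directory=Path("out"),
        rows_per_batch=10_000,
        output_format="csv",  # or "parquet"; new_writer() raises ValueError for anything else
    )
    for node in nodes:  # nodes: iterable of JsObject (assumed)
        writer.insert_node(node)  # buffers rows per table; flushes a batch once it exceeds rows_per_batch
    writer.close()  # writes any remaining rows, then closes each underlying Parquet/CSV file writer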