
Commit 279526a

Add csv export support (#17)
1 parent 0c38f81 commit 279526a

6 files changed: +128 -53 lines changed

File tree:
- README.md
- cloud2sql/__main__.py
- cloud2sql/collect_plugins.py
- cloud2sql/parquet.py
- tests/conftest.py
- tests/parquet_test.py

README.md

Lines changed: 14 additions & 2 deletions
@@ -96,11 +96,23 @@ destinations:

 ```
 destinations:
-  parquet:
-    path: /path/to/parquet/files
+  file:
+    path: /where/to/write/parquet/files/
+    format: parquet
     batch_size: 100_000
 ```

+#### CSV
+
+```
+destinations:
+  file:
+    path: /where/to/write/to/csv/files/
+    format: csv
+    batch_size: 100_000
+```
+
+
 #### My database is not listed here

 cloud2sql uses SQLAlchemy to connect to the database. If your database is not listed here, you can check if it is supported in [SQLAlchemy Dialects](https://docs.sqlalchemy.org/en/20/dialects/index.html).

cloud2sql/__main__.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def main() -> None:
     sender = NoEventSender() if args.analytics_opt_out else PosthogEventSender()
     config = configure(args.config)
     engine = None
-    if next(iter(config["destinations"].keys()), None) == "parquet":
+    if next(iter(config["destinations"].keys()), None) == "file":
         check_parquet_driver()
     else:
         engine = create_engine(db_string_from_config(config))
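
The destination key now decides the whole code path: `file` means only the pyarrow driver check runs and no SQLAlchemy engine is created. A minimal, self-contained sketch of that dispatch (the config dict below mirrors the README example and is invented for illustration):

```
# Illustrative sketch only: mirrors the destination check in cloud2sql/__main__.py above.
config = {
    "destinations": {
        "file": {"path": "/tmp/cloud2sql-out", "format": "csv", "batch_size": 100_000},
    }
}

first_destination = next(iter(config["destinations"].keys()), None)
if first_destination == "file":
    print("file destination: check the pyarrow driver, skip SQLAlchemy")
else:
    print("database destination: create a SQLAlchemy engine from the config")
```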

cloud2sql/collect_plugins.py

Lines changed: 33 additions & 11 deletions
@@ -8,8 +8,9 @@
 from queue import Queue
 from threading import Event
 from time import sleep
-from typing import Dict, Optional, List, Any, Tuple, Set
+from typing import Dict, Optional, List, Any, Tuple, Set, Literal
 from pathlib import Path
+from dataclasses import dataclass

 import pkg_resources
 import yaml
@@ -33,7 +34,7 @@
 from cloud2sql.sql import SqlUpdater, sql_updater

 try:
-    from cloud2sql.parquet import ParquetModel, ParquetWriter
+    from cloud2sql.parquet import ArrowModel, ArrowWriter
 except ImportError:
     pass

@@ -59,7 +60,18 @@ def collectors(raw_config: Json, feedback: CoreFeedback) -> Dict[str, BaseCollec
     return result


+@dataclass(frozen=True)
+class FileDestination:
+    path: Path
+    format: Literal["parquet", "csv"]
+    batch_size: int
+
+
 def configure(path_to_config: Optional[str]) -> Json:
+    def require(key: str, obj: Json, msg: str):
+        if key not in obj:
+            raise ValueError(msg)
+
     config = {}
     if path_to_config:
         with open(path_to_config) as f:
@@ -70,6 +82,16 @@ def configure(path_to_config: Optional[str]) -> Json:
     if "destinations" not in config:
         raise ValueError("No destinations configured")

+    if "file" in (config.get("destinations", {}) or {}):
+        file_dest = config["destinations"]["file"]
+        require("format", file_dest, "No format configured for file destination")
+        if not file_dest["format"] in ["parquet", "csv"]:
+            raise ValueError("Format must be either parquet or csv")
+        require("path", file_dest, "No path configured for file destination")
+        config["destinations"]["file"] = FileDestination(
+            Path(file_dest["path"]), file_dest["format"], int(file_dest.get("batch_size", 100_000))
+        )
+
     return config


@@ -79,7 +101,9 @@ def collect(
     if engine:
         return collect_sql(collector, engine, feedback, args)
     else:
-        return collect_parquet(collector, feedback, config)
+        if "file" not in config["destinations"]:
+            raise ValueError("No file destination configured")
+        return collect_to_file(collector, feedback, config["destinations"]["file"])


 def prepare_node(node: BaseResource, collector: BaseCollectorPlugin) -> Json:
@@ -95,13 +119,15 @@ def prepare_node(node: BaseResource, collector: BaseCollectorPlugin) -> Json:
     return exported


-def collect_parquet(collector: BaseCollectorPlugin, feedback: CoreFeedback, config: Json) -> Tuple[str, int, int]:
+def collect_to_file(
+    collector: BaseCollectorPlugin, feedback: CoreFeedback, config: FileDestination
+) -> Tuple[str, int, int]:
     # collect cloud data
     feedback.progress_done(collector.cloud, 0, 1)
     collector.collect()
     # read the kinds created from this collector
     kinds = [from_json(m, Kind) for m in collector.graph.export_model(walk_subclasses=False)]
-    model = ParquetModel(Model({k.fqn: k for k in kinds}))
+    model = ArrowModel(Model({k.fqn: k for k in kinds}))
     node_edge_count = len(collector.graph.nodes) + len(collector.graph.edges)
     ne_current = 0
     progress_update = node_edge_count // 100
@@ -115,11 +141,7 @@ def collect_parquet(collector: BaseCollectorPlugin, feedback: CoreFeedback, conf
     # create the ddl metadata from the kinds
     model.create_schema(list(edges_by_kind))
     # ingest the data
-    parquet_conf = config.get("destinations", {}).get("parquet")
-    assert parquet_conf
-    parquet_path = Path(parquet_conf["path"])
-    parquet_batch_size = int(parquet_conf["batch_size"])
-    writer = ParquetWriter(model, parquet_path, parquet_batch_size)
+    writer = ArrowWriter(model, config.path, config.batch_size, config.format)
     node: BaseResource
     for node in sorted(collector.graph.nodes, key=lambda n: n.kind):
         exported = prepare_node(node, collector)
@@ -214,7 +236,7 @@ def collect_from_plugins(engine: Optional[Engine], args: Namespace, sender: Anal
     raw_config = configure(args.config)
     sources = raw_config["sources"]
     all_collectors = collectors(sources, feedback)
-    engine_name = engine.dialect.name if engine else "parquet"
+    engine_name = engine.dialect.name if engine else "file"
     analytics = {"total": len(all_collectors), "engine": engine_name} | {name: 1 for name in all_collectors}
     end = Event()
     with ThreadPoolExecutor(max_workers=4) as executor:
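
To see what `configure` now hands to `collect_to_file`, here is a self-contained sketch of the same validation and conversion into `FileDestination`. It reimplements the hunk above rather than importing cloud2sql; the helper name `file_destination_from_config` and the example values are invented for illustration:

```
from dataclasses import dataclass
from pathlib import Path
from typing import Literal


@dataclass(frozen=True)
class FileDestination:
    path: Path
    format: Literal["parquet", "csv"]
    batch_size: int


def file_destination_from_config(destinations: dict) -> FileDestination:
    # Mirrors the validation in configure(): format and path are required,
    # batch_size falls back to 100_000 rows.
    file_dest = destinations["file"]
    if "format" not in file_dest:
        raise ValueError("No format configured for file destination")
    if file_dest["format"] not in ("parquet", "csv"):
        raise ValueError("Format must be either parquet or csv")
    if "path" not in file_dest:
        raise ValueError("No path configured for file destination")
    return FileDestination(
        Path(file_dest["path"]), file_dest["format"], int(file_dest.get("batch_size", 100_000))
    )


# Example values are invented for illustration.
print(file_destination_from_config({"file": {"path": "/tmp/out", "format": "csv"}}))
# -> FileDestination(path=PosixPath('/tmp/out'), format='csv', batch_size=100000)
```

In cloud2sql itself, the converted value replaces the raw dict under `config["destinations"]["file"]`, so `collect_to_file` receives a typed object instead of re-reading the YAML.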

cloud2sql/parquet.py

Lines changed: 74 additions & 33 deletions
@@ -1,6 +1,7 @@
 from resotoclient.models import Kind, Model, JsObject
-from typing import Dict, List, Any, NamedTuple, Optional, Tuple
+from typing import Dict, List, Any, NamedTuple, Optional, Tuple, final, Literal
 import pyarrow as pa
+import pyarrow.csv as csv
 from cloud2sql.schema_utils import (
     base_kinds,
     get_table_name,
@@ -11,9 +12,10 @@
 import pyarrow.parquet as pq
 from pathlib import Path
 from dataclasses import dataclass
+from abc import ABC


-class ParquetModel:
+class ArrowModel:
     def __init__(self, model: Model):
         self.model = model
         self.table_kinds = [
@@ -23,7 +25,7 @@ def __init__(self, model: Model):
         ]
         self.schemas: Dict[str, pa.Schema] = {}

-    def _parquet_type(self, kind: str) -> pa.lib.DataType:
+    def _pyarrow_type(self, kind: str) -> pa.lib.DataType:
         if kind.startswith("dict") or "[]" in kind:
             return pa.string()  # dicts and lists are converted to json strings
         elif kind == "int32":
@@ -49,7 +51,7 @@ def table_schema(kind: Kind) -> None:
             schema = pa.schema(
                 [
                     pa.field("_id", pa.string()),
-                    *[pa.field(p.name, self._parquet_type(p.kind)) for p in properties],
+                    *[pa.field(p.name, self._pyarrow_type(p.kind)) for p in properties],
                 ]
             )
             self.schemas[table_name] = schema
@@ -90,41 +92,85 @@ class WriteResult(NamedTuple):
     table_name: str


+class FileWriter(ABC):
+    pass
+
+
+@final
+@dataclass(frozen=True)
+class Parquet(FileWriter):
+    parquet_writer: pq.ParquetWriter
+
+
+@final
+@dataclass(frozen=True)
+class CSV(FileWriter):
+    csv_writer: csv.CSVWriter
+
+
+@final
 @dataclass
-class ParquetBatch:
+class ArrowBatch:
     rows: List[Dict[str, Any]]
     schema: pa.Schema
-    writer: pq.ParquetWriter
+    writer: FileWriter
+
+
+def write_batch_to_file(batch: ArrowBatch) -> ArrowBatch:
+    pa_table = pa.Table.from_pylist(batch.rows, batch.schema)
+    if isinstance(batch.writer, Parquet):
+        batch.writer.parquet_writer.write_table(pa_table)
+    elif isinstance(batch.writer, CSV):
+        batch.writer.csv_writer.write_table(pa_table)
+    else:
+        raise ValueError(f"Unknown format {batch.writer}")
+    return ArrowBatch(rows=[], schema=batch.schema, writer=batch.writer)
+

+def close_writer(batch: ArrowBatch) -> None:
+    if isinstance(batch.writer, Parquet):
+        batch.writer.parquet_writer.close()
+    elif isinstance(batch.writer, CSV):
+        batch.writer.csv_writer.close()
+    else:
+        raise ValueError(f"Unknown format {batch.writer}")

-class ParquetWriter:
+
+def new_writer(
+    format: Literal["parquet", "csv"], table_name: str, schema: pa.Schema, result_dir: Path
+) -> FileWriter:
+    def ensure_path(path: Path) -> Path:
+        path.mkdir(parents=True, exist_ok=True)
+        return path
+
+    if format == "parquet":
+        return Parquet(pq.ParquetWriter(Path(ensure_path(result_dir), f"{table_name}.parquet"), schema=schema))
+    elif format == "csv":
+        return CSV(csv.CSVWriter(Path(ensure_path(result_dir), f"{table_name}.csv"), schema=schema))
+    else:
+        raise ValueError(f"Unknown format {format}")
+
+
+class ArrowWriter:
     def __init__(
-        self,
-        model: ParquetModel,
-        result_directory: Path,
-        rows_per_batch: int,
+        self, model: ArrowModel, result_directory: Path, rows_per_batch: int, output_format: Literal["parquet", "csv"]
     ):
         self.model = model
         self.kind_by_id: Dict[str, str] = {}
-        self.batches: Dict[str, ParquetBatch] = {}
-        self.rows_per_batch = rows_per_batch
-        self.result_directory = result_directory
+        self.batches: Dict[str, ArrowBatch] = {}
+        self.rows_per_batch: int = rows_per_batch
+        self.result_directory: Path = result_directory
+        self.output_format: Literal["parquet", "csv"] = output_format

     def insert_value(self, table_name: str, values: Any) -> Optional[WriteResult]:
         if self.model.schemas.get(table_name):
-
-            def ensure_path(path: Path) -> Path:
-                path.mkdir(parents=True, exist_ok=True)
-                return path
-
             batch = self.batches.get(
                 table_name,
-                ParquetBatch(
+                ArrowBatch(
                     [],
                     self.model.schemas[table_name],
-                    pq.ParquetWriter(
-                        Path(ensure_path(self.result_directory), f"{table_name}.parquet"),
-                        self.model.schemas[table_name],
+                    new_writer(
+                        self.output_format, table_name, self.model.schemas[table_name], self.result_directory
                     ),
                 ),
             )
@@ -134,12 +180,6 @@ def ensure_path(path: Path) -> Path:
             return WriteResult(table_name)
         return None

-    def write_batch_bundle(self, batch: ParquetBatch) -> None:
-        rows = batch.rows
-        batch.rows = []
-        pa_table = pa.Table.from_pylist(rows, batch.schema)
-        batch.writer.write_table(pa_table)
-
     def insert_node(self, node: JsObject) -> None:
         result = insert_node(
             node,
@@ -151,9 +191,10 @@ def insert_node(self, node: JsObject) -> None:
         should_write_batch = result and len(self.batches[result.table_name].rows) > self.rows_per_batch
         if result and should_write_batch:
             batch = self.batches[result.table_name]
-            self.write_batch_bundle(batch)
+            self.batches[result.table_name] = write_batch_to_file(batch)

     def close(self) -> None:
-        for batch in self.batches.values():
-            self.write_batch_bundle(batch)
-            batch.writer.close()
+        for table_name, batch in self.batches.items():
+            batch = write_batch_to_file(batch)
+            self.batches[table_name] = batch
+            close_writer(batch)
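
The heart of this file's change is that a batch no longer holds a `pq.ParquetWriter` directly but a `FileWriter` wrapper that is either `Parquet` or `CSV`, with `write_batch_to_file` and `close_writer` dispatching on that type. A minimal standalone pyarrow sketch of the two underlying writers, independent of the cloud2sql classes (schema, rows, and file names are invented for illustration):

```
from pathlib import Path

import pyarrow as pa
import pyarrow.csv as csv
import pyarrow.parquet as pq

# Invented example schema and rows, just to exercise both writer kinds.
schema = pa.schema([pa.field("_id", pa.string()), pa.field("name", pa.string())])
rows = [{"_id": "i-1", "name": "instance-1"}, {"_id": "i-2", "name": "instance-2"}]
table = pa.Table.from_pylist(rows, schema=schema)

out_dir = Path("arrow-demo")
out_dir.mkdir(parents=True, exist_ok=True)

# Parquet: the same pq.ParquetWriter the Parquet wrapper holds.
parquet_writer = pq.ParquetWriter(out_dir / "some_instance.parquet", schema=schema)
parquet_writer.write_table(table)
parquet_writer.close()

# CSV: pyarrow.csv.CSVWriter, as wrapped by the CSV dataclass.
csv_writer = csv.CSVWriter(out_dir / "some_instance.csv", schema=schema)
csv_writer.write_table(table)
csv_writer.close()
```

Both writers take the schema up front, which is why `ArrowBatch` can flush partial batches repeatedly via `write_batch_to_file` and only close each writer once in `close()`.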

tests/conftest.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 from sqlalchemy.engine import create_engine, Engine

 from cloud2sql.sql import SqlDefaultUpdater
-from cloud2sql.parquet import ParquetModel, ParquetWriter
+from cloud2sql.parquet import ArrowModel, ArrowWriter
 from pathlib import Path
 import shutil
 import uuid
@@ -70,12 +70,12 @@ def updater(model: Model) -> SqlDefaultUpdater:

 @fixture()
 def parquet_writer(model: Model):
-    parquet_model = ParquetModel(model)
+    parquet_model = ArrowModel(model)
     parquet_model.create_schema([])

     p = Path(f"test_parquet_{uuid.uuid4()}")
     p.mkdir(exist_ok=True)
-    yield ParquetWriter(parquet_model, p, 1)
+    yield ArrowWriter(parquet_model, p, 1, "parquet")
     shutil.rmtree(p)

tests/parquet_test.py

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
 from resotoclient.models import Model

-from cloud2sql.parquet import ParquetModel, ParquetWriter
+from cloud2sql.parquet import ArrowModel, ArrowWriter


 def test_create_schema(model: Model) -> None:
-    parquet_model = ParquetModel(model)
+    parquet_model = ArrowModel(model)
     parquet_model.create_schema([])

     assert parquet_model.schemas.keys() == {"some_instance", "some_volume", "link_some_instance_some_volume"}
@@ -32,7 +32,7 @@ def test_create_schema(model: Model) -> None:
     assert set(parquet_model.schemas["link_some_instance_some_volume"].names) == {"to_id", "from_id"}


-def test_update(parquet_writer: ParquetWriter) -> None:
+def test_update(parquet_writer: ArrowWriter) -> None:

     parquet_writer.insert_node(  # type: ignore
         {
