
Commit b7ccfd1

Support types in Parquet (#24)
1 parent c469a8e commit b7ccfd1

10 files changed: +461 -209 lines


cloud2sql/arrow/model.py

Lines changed: 72 additions & 0 deletions
from resotoclient.models import Kind, Model
from typing import Dict, List, Tuple, Literal
import pyarrow as pa
from cloud2sql.schema_utils import (
    base_kinds,
    get_table_name,
    get_link_table_name,
    kind_properties,
)
from cloud2sql.arrow.type_converter import parquet_pyarrow_type, csv_pyarrow_type
from functools import partial


class ArrowModel:
    def __init__(self, model: Model, output_format: Literal["parquet", "csv"]):
        self.model = model
        self.table_kinds = [
            kind
            for kind in model.kinds.values()
            if kind.aggregate_root and kind.runtime_kind is None and kind.fqn not in base_kinds
        ]
        self.schemas: Dict[str, pa.Schema] = {}
        if output_format == "parquet":
            self.pyarrow_type = partial(parquet_pyarrow_type, model=model)
        elif output_format == "csv":
            self.pyarrow_type = partial(csv_pyarrow_type, model=model)
        else:
            raise Exception(f"Unknown output format {output_format}")

    def create_schema(self, edges: List[Tuple[str, str]]) -> None:
        def table_schema(kind: Kind) -> None:
            table_name = get_table_name(kind.fqn, with_tmp_prefix=False)
            if table_name not in self.schemas:
                properties, _ = kind_properties(kind, self.model)
                schema = pa.schema(
                    [
                        pa.field("_id", pa.string()),
                        *[pa.field(p.name, self.pyarrow_type(p.kind)) for p in properties],
                    ]
                )
                self.schemas[table_name] = schema

        def link_table_schema(from_kind: str, to_kind: str) -> None:
            from_table = get_table_name(from_kind, with_tmp_prefix=False)
            to_table = get_table_name(to_kind, with_tmp_prefix=False)
            link_table = get_link_table_name(from_kind, to_kind, with_tmp_prefix=False)
            if link_table not in self.schemas and from_table in self.schemas and to_table in self.schemas:
                schema = pa.schema(
                    [
                        pa.field("from_id", pa.string()),
                        pa.field("to_id", pa.string()),
                    ]
                )
                self.schemas[link_table] = schema

        def link_table_schema_from_successors(kind: Kind) -> None:
            _, successors = kind_properties(kind, self.model)
            # create link tables for all linked entities
            for successor in successors:
                link_table_schema(kind.fqn, successor)

        # step 1: create tables for all kinds
        for kind in self.table_kinds:
            table_schema(kind)
        # step 2: create link tables for all kinds
        for kind in self.table_kinds:
            link_table_schema_from_successors(kind)
        # step 3: create link tables for all seen edges
        for from_kind, to_kind in edges:
            link_table_schema(from_kind, to_kind)

        return None
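For illustration only (not part of this commit), a hedged sketch of how the class above might be driven; `model` is assumed to be a resotoclient `Model` built elsewhere, and the edge pair is a made-up example:

    # Hypothetical usage sketch: build per-kind schemas plus link tables.
    arrow_model = ArrowModel(model, output_format="parquet")
    arrow_model.create_schema(edges=[("aws_account", "aws_region")])
    for table, schema in arrow_model.schemas.items():
        print(table, schema)  # one pyarrow schema per table, each starting with an _id string field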

cloud2sql/arrow/type_converter.py

Lines changed: 55 additions & 0 deletions
import pyarrow as pa
from resotoclient.models import Model
from cloud2sql.schema_utils import kind_properties


def parquet_pyarrow_type(kind: str, model: Model) -> pa.lib.DataType:
    if "[]" in kind:
        return pa.list_(parquet_pyarrow_type(kind.strip("[]"), model))
    elif kind.startswith("dictionary"):
        (key_kind, value_kind) = kind.strip("dictionary").strip("[]").split(",")
        return pa.map_(parquet_pyarrow_type(key_kind.strip(), model), parquet_pyarrow_type(value_kind.strip(), model))
    elif kind == "int32":
        return pa.int32()
    elif kind == "int64":
        return pa.int64()
    elif kind == "float":
        return pa.float32()
    elif kind == "double":
        return pa.float64()
    elif kind in {"string", "datetime", "date", "duration", "any"}:
        return pa.string()
    elif kind == "boolean":
        return pa.bool_()
    elif kind in model.kinds:
        nested_kind = model.kinds[kind]
        if nested_kind.runtime_kind is not None:
            return parquet_pyarrow_type(nested_kind.runtime_kind, model)

        properties, _ = kind_properties(nested_kind, model)
        return pa.struct([pa.field(p.name, parquet_pyarrow_type(p.kind, model)) for p in properties])
    else:
        raise Exception(f"Unknown kind {kind}")


def csv_pyarrow_type(kind: str, model: Model) -> pa.lib.DataType:
    if "[]" in kind:
        return pa.string()
    elif kind.startswith("dictionary"):
        return pa.string()
    elif kind == "int32":
        return pa.int32()
    elif kind == "int64":
        return pa.int64()
    elif kind == "float":
        return pa.float32()
    elif kind == "double":
        return pa.float64()
    elif kind in {"string", "datetime", "date", "duration", "any"}:
        return pa.string()
    elif kind == "boolean":
        return pa.bool_()
    elif kind in model.kinds:
        return pa.string()
    else:
        raise Exception(f"Unknown kind {kind}")
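A hedged sketch of the mapping these two functions perform; an empty `Model({})` is assumed to be constructible (the commit itself builds `Model` from a plain dict) and suffices for primitive kinds:

    # Illustration only: primitive and composite kinds resolve differently per format.
    m = Model({})
    parquet_pyarrow_type("int64[]", m)                    # -> list<item: int64>
    parquet_pyarrow_type("dictionary[string, int64]", m)  # -> map<string, int64>
    csv_pyarrow_type("dictionary[string, int64]", m)      # -> string (CSV flattens complex kinds)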

cloud2sql/arrow/writer.py

Lines changed: 236 additions & 0 deletions
from typing import Dict, List, Any, NamedTuple, Optional, final, Literal
import pyarrow.csv as csv
from dataclasses import dataclass
import dataclasses
from abc import ABC
import pyarrow.parquet as pq
import pyarrow as pa
import json
from pathlib import Path
from cloud2sql.arrow.model import ArrowModel
from cloud2sql.schema_utils import insert_node
from resotoclient.models import JsObject


class WriteResult(NamedTuple):
    table_name: str


class FileWriter(ABC):
    pass


@final
@dataclass(frozen=True)
class Parquet(FileWriter):
    parquet_writer: pq.ParquetWriter


@final
@dataclass(frozen=True)
class CSV(FileWriter):
    csv_writer: csv.CSVWriter


@final
@dataclass
class ArrowBatch:
    table_name: str
    rows: List[Dict[str, Any]]
    schema: pa.Schema
    writer: FileWriter


class ConversionTarget(ABC):
    pass


@final
@dataclass(frozen=True)
class ParquetMap(ConversionTarget):
    convert_values_to_str: bool


@final
@dataclass(frozen=True)
class ParquetString(ConversionTarget):
    pass


@dataclass
class NormalizationPath:
    path: List[Optional[str]]
    convert_to: ConversionTarget


# workaround until the fix is merged: https://issues.apache.org/jira/browse/ARROW-17832
#
# Here we collect the paths to the JSON object fields that we want to convert to arrow types,
# so that we can apply the transformations later.
#
# Currently we convert dict -> list[(key, value)], and values which are defined as strings
# in the schema are converted to strings/JSON strings.
def collect_normalization_paths(schema: pa.Schema) -> List[NormalizationPath]:
    paths: List[NormalizationPath] = []

    def collect_paths_to_maps_helper(path: List[Optional[str]], typ: pa.DataType) -> None:
        # if we see a map, we stop and add the path to the list.
        # if the value type is string, we remember that too
        if isinstance(typ, pa.lib.MapType):
            stringify_items = pa.types.is_string(typ.item_type)
            normalization_path = NormalizationPath(path, ParquetMap(stringify_items))
            paths.append(normalization_path)
        # structs are traversed, but are of no interest by themselves
        elif isinstance(typ, pa.lib.StructType):
            for field_idx in range(0, typ.num_fields):
                field = typ.field(field_idx)
                collect_paths_to_maps_helper(path + [field.name], field.type)
        # lists are traversed too. None is added to the path to be consumed by the recursion
        # in order to reach the correct level
        elif isinstance(typ, pa.lib.ListType):
            collect_paths_to_maps_helper(path + [None], typ.value_type)
        # if we see a string, we stop and add the path to the list
        elif pa.types.is_string(typ):
            normalization_path = NormalizationPath(path, ParquetString())
            paths.append(normalization_path)

    # bootstrap the recursion
    for idx, typ in enumerate(schema.types):
        collect_paths_to_maps_helper([schema.names[idx]], typ)

    return paths
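To make the traversal concrete, a hedged example (the schema and field names are invented for illustration):

    # Illustration only: one map column and one list-of-struct column.
    schema = pa.schema(
        [
            pa.field("tags", pa.map_(pa.string(), pa.string())),
            pa.field("rules", pa.list_(pa.struct([pa.field("body", pa.string())]))),
        ]
    )
    paths = collect_normalization_paths(schema)
    # -> [NormalizationPath(path=["tags"], convert_to=ParquetMap(convert_values_to_str=True)),
    #     NormalizationPath(path=["rules", None, "body"], convert_to=ParquetString())]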
def normalize(npath: NormalizationPath, obj: Any) -> Any:
    path = npath.path
    reached_target = len(path) == 0

    # we're on the correct node, time to convert it into something
    if reached_target:
        if isinstance(npath.convert_to, ParquetString):
            # everything that should be a string is either a string or a JSON string
            return obj if isinstance(obj, str) else json.dumps(obj)
        elif isinstance(npath.convert_to, ParquetMap):
            # we can only convert dicts to maps. if that is not the case, it is a bug
            if not isinstance(obj, dict):
                raise Exception(f"Expected dict, got {type(obj)}. path: {npath}")

            def value_to_string(v: Any) -> str:
                if isinstance(v, str):
                    return v
                else:
                    return json.dumps(v)

            # in case the map should contain string values, we convert them to strings
            if npath.convert_to.convert_values_to_str:
                return [(k, value_to_string(v)) for k, v in obj.items()]
            else:
                return list(obj.items())
    # we're not at the target node yet, so we traverse the tree deeper
    else:
        # if we see a dict, we try to go deeper in case it contains the key we are looking for;
        # otherwise we return the object as is. This is valid because the fields are optional
        if isinstance(obj, dict):
            key = path[0]
            if key in obj:
                # consume the current element of the path
                new_npath = dataclasses.replace(npath, path=path[1:])
                obj[key] = normalize(new_npath, obj[key])
            return obj
        # in case of a list, we process all its elements
        elif isinstance(obj, list):
            # check that the path is correct
            assert path[0] is None
            # consume the current element of the path
            new_npath = dataclasses.replace(npath, path=path[1:])
            return [normalize(new_npath, v) for v in obj]
        else:
            raise Exception(f"Unexpected object type {type(obj)}, path: {npath}")
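A hedged sketch of applying one collected path to a row (data invented for illustration; note that non-string map values are JSON-encoded):

    # Illustration only: converts the nested dict to a list of key/value pairs.
    row = {"tags": {"env": "prod", "count": 2}}
    npath = NormalizationPath(path=["tags"], convert_to=ParquetMap(convert_values_to_str=True))
    normalize(npath, row)  # mutates the row in place
    # row is now {"tags": [("env", "prod"), ("count", "2")]}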
def write_batch_to_file(batch: ArrowBatch) -> ArrowBatch:
    to_normalize = collect_normalization_paths(batch.schema)

    for row in batch.rows:
        for path in to_normalize:
            normalize(path, row)

    pa_table = pa.Table.from_pylist(batch.rows, batch.schema)
    if isinstance(batch.writer, Parquet):
        batch.writer.parquet_writer.write_table(pa_table)
    elif isinstance(batch.writer, CSV):
        batch.writer.csv_writer.write_table(pa_table)
    else:
        raise ValueError(f"Unknown format {batch.writer}")
    return ArrowBatch(table_name=batch.table_name, rows=[], schema=batch.schema, writer=batch.writer)


def close_writer(batch: ArrowBatch) -> None:
    if isinstance(batch.writer, Parquet):
        batch.writer.parquet_writer.close()
    elif isinstance(batch.writer, CSV):
        batch.writer.csv_writer.close()
    else:
        raise ValueError(f"Unknown format {batch.writer}")


def new_writer(format: Literal["parquet", "csv"], table_name: str, schema: pa.Schema, result_dir: Path) -> FileWriter:
    def ensure_path(path: Path) -> Path:
        path.mkdir(parents=True, exist_ok=True)
        return path

    if format == "parquet":
        return Parquet(pq.ParquetWriter(Path(ensure_path(result_dir), f"{table_name}.parquet"), schema=schema))
    elif format == "csv":
        return CSV(csv.CSVWriter(Path(ensure_path(result_dir), f"{table_name}.csv"), schema=schema))
    else:
        raise ValueError(f"Unknown format {format}")
class ArrowWriter:
    def __init__(
        self, model: ArrowModel, result_directory: Path, rows_per_batch: int, output_format: Literal["parquet", "csv"]
    ):
        self.model = model
        self.kind_by_id: Dict[str, str] = {}
        self.batches: Dict[str, ArrowBatch] = {}
        self.rows_per_batch: int = rows_per_batch
        self.result_directory: Path = result_directory
        self.output_format: Literal["parquet", "csv"] = output_format

    def insert_value(self, table_name: str, values: Any) -> Optional[WriteResult]:
        if self.model.schemas.get(table_name):
            schema = self.model.schemas[table_name]
            batch = self.batches.get(
                table_name,
                ArrowBatch(
                    table_name,
                    [],
                    schema,
                    new_writer(self.output_format, table_name, schema, self.result_directory),
                ),
            )

            batch.rows.append(values)
            self.batches[table_name] = batch
            return WriteResult(table_name)
        return None

    def insert_node(self, node: JsObject) -> None:
        result = insert_node(
            node,
            self.kind_by_id,
            self.insert_value,
            with_tmp_prefix=False,
        )
        should_write_batch = result and len(self.batches[result.table_name].rows) > self.rows_per_batch
        if result and should_write_batch:
            batch = self.batches[result.table_name]
            self.batches[result.table_name] = write_batch_to_file(batch)

    def close(self) -> None:
        for table_name, batch in self.batches.items():
            batch = write_batch_to_file(batch)
            self.batches[table_name] = batch
            close_writer(batch)
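A hedged sketch of the write path end to end; `arrow_model` is assumed to be an `ArrowModel` with schemas already created (see model.py above), and `nodes` is a placeholder for the exported graph documents, not a name from this commit:

    # Hypothetical wiring of the classes above.
    writer = ArrowWriter(arrow_model, Path("out"), rows_per_batch=100_000, output_format="parquet")
    for node in nodes:
        writer.insert_node(node)  # buffers rows; flushes a table once its batch exceeds rows_per_batch
    writer.close()  # writes any remaining rows and closes the underlying files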

cloud2sql/collect_plugins.py

Lines changed: 4 additions & 3 deletions
@@ -34,7 +34,8 @@
 from cloud2sql.sql import SqlUpdater, sql_updater

 try:
-    from cloud2sql.parquet import ArrowModel, ArrowWriter
+    from cloud2sql.arrow.model import ArrowModel
+    from cloud2sql.arrow.writer import ArrowWriter
 except ImportError:
     pass

@@ -127,10 +128,10 @@ def collect_to_file(
     collector.collect()
     # read the kinds created from this collector
     kinds = [from_json(m, Kind) for m in collector.graph.export_model(walk_subclasses=False)]
-    model = ArrowModel(Model({k.fqn: k for k in kinds}))
+    model = ArrowModel(Model({k.fqn: k for k in kinds}), config.format)
     node_edge_count = len(collector.graph.nodes) + len(collector.graph.edges)
     ne_current = 0
-    progress_update = node_edge_count // 100
+    progress_update = max(node_edge_count // 100, 1)
     feedback.progress_done("sync_db", 0, node_edge_count, context=[collector.cloud])

     # group all edges by kind of from/to
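The `max(..., 1)` guard deserves a note: for graphs with fewer than 100 nodes and edges, the floor division yields 0, and a zero step would break any later modulo-style progress check. A hedged one-liner (the downstream use of `progress_update` is assumed, not shown in this hunk):

    progress_update = max(42 // 100, 1)  # -> 1 instead of 0, so `ne_current % progress_update` cannot divide by zero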
