
Commit 1b91b8b

feat: Support Snowflake (#9)
Parent commit: 75e4794

File tree

11 files changed: +303 −174 lines


.github/workflows/build_and_publish.yml

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ jobs:
       - name: Setup Python
         uses: actions/setup-python@v4
         with:
-          python-version: '3.11'
+          python-version: '3.10'
           architecture: 'x64'

       - name: Restore dependency cache

cloud2sql/__main__.py

Lines changed: 15 additions & 4 deletions

@@ -4,9 +4,14 @@
 from resotolib.logger import setup_logger
 from sqlalchemy import create_engine
 from sqlalchemy.engine import Engine
-
 from cloud2sql.collect_plugins import collect_from_plugins

+# Will fail in case snowflake is not installed - which is fine.
+try:
+    from cloud2sql.snowflake import SnowflakeUpdater  # noqa: F401
+except ImportError:
+    pass
+
 log = getLogger("cloud2sql")


@@ -40,9 +45,15 @@ def collect(engine: Engine, args: Namespace) -> None:

 def main() -> None:
     args = parse_args()
-    setup_logger("cloud2sql", level=args.log_level, force=True)
-    engine = create_engine(args.db)
-    collect(engine, args)
+    try:
+        setup_logger("cloud2sql", level=args.log_level, force=True)
+        engine = create_engine(args.db)
+        collect(engine, args)
+    except Exception as e:
+        if args.debug:  # re-raise to show the complete traceback
+            raise e
+        else:
+            print(f"Error syncing data to database: {e}")


 if __name__ == "__main__":
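
The optional import works because loading cloud2sql.snowflake registers the Snowflake updater as a side effect (see the last line of the new cloud2sql/snowflake.py below). A minimal, self-contained sketch of that pattern — assuming DialectUpdater is a plain dict from dialect name to updater class, which the registration line in snowflake.py suggests, and using stand-in classes rather than the real ones:

from typing import Dict, Type

class SqlDefaultUpdater:  # stand-in for cloud2sql.sql.SqlDefaultUpdater
    pass

# registry: dialect name -> updater class; lives in cloud2sql.sql in the real code
DialectUpdater: Dict[str, Type[SqlDefaultUpdater]] = {}

try:
    import snowflake.sqlalchemy  # noqa: F401  # only present with the optional extra

    class SnowflakeUpdater(SqlDefaultUpdater):  # simplified stand-in
        pass

    DialectUpdater["snowflake"] = SnowflakeUpdater  # registration side effect
except ImportError:
    pass  # snowflake-sqlalchemy not installed: only the default updater is available

def updater_class(dialect_name: str) -> Type[SqlDefaultUpdater]:
    # presumably how sql_updater() resolves the class for an engine's dialect
    return DialectUpdater.get(dialect_name, SqlDefaultUpdater)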

cloud2sql/collect_plugins.py

Lines changed: 36 additions & 19 deletions

@@ -1,5 +1,6 @@
 import concurrent
 import multiprocessing
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, Future
 from contextlib import suppress
 from logging import getLogger
@@ -25,9 +26,9 @@
 from sqlalchemy.engine import Engine

 from cloud2sql.show_progress import CollectInfo
-from cloud2sql.sql import SqlModel, SqlUpdater
+from cloud2sql.sql import SqlUpdater, sql_updater

-log = getLogger("cloud2sql")
+log = getLogger("resoto.cloud2sql")


 def collectors(raw_config: Json, feedback: CoreFeedback) -> Dict[str, BaseCollectorPlugin]:
@@ -62,24 +63,25 @@ def collect(collector: BaseCollectorPlugin, engine: Engine, feedback: CoreFeedba
     collector.collect()
     # read the kinds created from this collector
     kinds = [from_json(m, Kind) for m in collector.graph.export_model(walk_subclasses=False)]
-    model = SqlModel(Model({k.fqn: k for k in kinds}))
+    updater = sql_updater(Model({k.fqn: k for k in kinds}), engine)
     node_edge_count = len(collector.graph.nodes) + len(collector.graph.edges)
-    ne_count = iter(range(0, node_edge_count))
-    progress_update = max(node_edge_count // 100, 50)
+    ne_count = 0
     schema = f"create temp tables {engine.dialect.name}"
     syncdb = f"synchronize {engine.dialect.name}"
     feedback.progress_done(schema, 0, 1, context=[collector.cloud])
     feedback.progress_done(syncdb, 0, node_edge_count, context=[collector.cloud])
     with engine.connect() as conn:
         with conn.begin():
             # create the ddl metadata from the kinds
-            model.create_schema(conn, args)
+            updater.create_schema(conn, args)
             feedback.progress_done(schema, 1, 1, context=[collector.cloud])
-            # ingest the data
-            updater = SqlUpdater(model)
+
+            # group all nodes by kind
+            nodes_by_kind = defaultdict(list)
             node: BaseResource
             for node in collector.graph.nodes:
                 node._graph = collector.graph
+                # create an exported node with the same scheme as resotocore
                 exported = node_to_dict(node)
                 exported["type"] = "node"
                 exported["ancestors"] = {
@@ -88,17 +90,29 @@ def collect(collector: BaseCollectorPlugin, engine: Engine, feedback: CoreFeedba
                     "region": {"reported": {"id": node.region().name}},
                     "zone": {"reported": {"id": node.zone().name}},
                 }
-                stmt = updater.insert_node(exported)
-                if stmt is not None:
-                    conn.execute(stmt)
-                if (nx := next(ne_count)) % progress_update == 0:
-                    feedback.progress_done(syncdb, nx, node_edge_count, context=[collector.cloud])
+                nodes_by_kind[node.kind].append(exported)
+
+            # insert batches of nodes by kind
+            for kind, nodes in nodes_by_kind.items():
+                log.info(f"Inserting {len(nodes)} nodes of kind {kind}")
+                for insert in updater.insert_nodes(kind, nodes):
+                    conn.execute(insert)
+                ne_count += len(nodes)
+                feedback.progress_done(syncdb, ne_count, node_edge_count, context=[collector.cloud])
+
+            # group all edges by the kinds of their from/to nodes
+            edges_by_kind = defaultdict(list)
             for from_node, to_node, _ in collector.graph.edges:
-                stmt = updater.insert_node({"from": from_node.chksum, "to": to_node.chksum, "type": "edge"})
-                if stmt is not None:
-                    conn.execute(stmt)
-                if (nx := next(ne_count)) % progress_update == 0:
-                    feedback.progress_done(syncdb, nx, node_edge_count, context=[collector.cloud])
+                edge_node = {"from": from_node.chksum, "to": to_node.chksum, "type": "edge"}
+                edges_by_kind[(from_node.kind, to_node.kind)].append(edge_node)
+
+            # insert batches of edges by from/to kind
+            for from_to, nodes in edges_by_kind.items():
+                log.info(f"Inserting {len(nodes)} edges from {from_to[0]} to {from_to[1]}")
+                for insert in updater.insert_edges(from_to, nodes):
+                    conn.execute(insert)
+                ne_count += len(nodes)
+                feedback.progress_done(syncdb, ne_count, node_edge_count, context=[collector.cloud])
     feedback.progress_done(collector.cloud, 1, 1)


@@ -131,7 +145,10 @@ def collect_from_plugins(engine: Engine, args: Namespace) -> None:
         for future in concurrent.futures.as_completed(futures):
             future.result()
         # when all collectors are done, we can swap all temp tables
-        SqlModel.swap_temp_tables(engine)
+        swap_tables = "Make latest snapshot available"
+        feedback.progress_done(swap_tables, 0, 1)
+        SqlUpdater.swap_temp_tables(engine)
+        feedback.progress_done(swap_tables, 1, 1)
     except Exception as e:
         # set end and wait for live to finish, otherwise the cursor is not reset
         end.set()
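
The key change in collect() is that rows are no longer inserted one statement at a time: nodes are first grouped by kind and edges by (from-kind, to-kind), then each group is written in batches, with progress reported per group. A minimal sketch of the grouping and batching, with a hypothetical batch size and plain dicts standing in for the exported nodes:

from collections import defaultdict
from typing import Dict, List

exported_nodes = [  # stand-ins for node_to_dict() results
    {"kind": "aws_instance", "_id": "i-1"},
    {"kind": "aws_instance", "_id": "i-2"},
    {"kind": "aws_volume", "_id": "v-1"},
]

nodes_by_kind: Dict[str, List[dict]] = defaultdict(list)
for node in exported_nodes:
    nodes_by_kind[node["kind"]].append(node)

batch_size = 2  # plays the role of the updater's insert_batch_size
done, total = 0, len(exported_nodes)
for kind, group in nodes_by_kind.items():
    # one INSERT per slice; updater.insert_nodes(kind, group) yields these in the real code
    batches = [group[i : i + batch_size] for i in range(0, len(group), batch_size)]
    done += len(group)
    print(f"{kind}: {len(batches)} batch inserts, {done}/{total} nodes synchronized")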

cloud2sql/collect_resoto.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

cloud2sql/snowflake.py

Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+import json
+import logging
+from typing import Any, List, Iterator
+
+from resotoclient import Model
+from resotoclient.models import Property
+from resotolib.types import Json
+from snowflake.sqlalchemy import ARRAY, OBJECT
+from sqlalchemy import Integer, Float, String, Boolean, column
+from sqlalchemy import select
+from sqlalchemy.sql import Values
+from sqlalchemy.sql.dml import ValuesBase
+
+from cloud2sql.sql import SqlDefaultUpdater, DialectUpdater
+
+log = logging.getLogger("resoto.cloud2sql.snowflake")
+
+
+def kind_to_snowflake_type(kind_name: str, model: Model) -> Any:  # Type[TypeEngine[Any]]
+    """
+    Map internal kinds to Snowflake types.
+    More or less the default mapping, but with special cases for the OBJECT and ARRAY types.
+    """
+    kind = model.kinds.get(kind_name)
+    if "[]" in kind_name:
+        return ARRAY
+    elif kind_name.startswith("dict"):
+        return OBJECT
+    elif kind_name == "any":
+        return OBJECT
+    elif kind_name in ("int32", "int64"):
+        return Integer
+    elif kind_name == "float":
+        return Float
+    elif kind_name == "double":
+        return Float  # use Double with sqlalchemy 2
+    elif kind_name in ("string", "date", "datetime", "duration"):
+        return String
+    elif kind_name == "boolean":
+        return Boolean
+    elif kind is not None and kind.runtime_kind is not None:  # refined simple type like enum
+        return kind_to_snowflake_type(kind.runtime_kind, model)
+    elif kind is not None and kind.properties:  # complex kind
+        return OBJECT
+    else:
+        raise ValueError(f"Not able to handle kind {kind_name}")
+
+
+class SnowflakeUpdater(SqlDefaultUpdater):
+    """
+    This updater synchronizes resource data to Snowflake (https://www.snowflake.com).
+    Snowflake needs special handling, since it does not support default json or array column types.
+    It also does not accept json or array values as bind parameters.
+    This updater works around those shortcomings with special insert statements.
+    """
+
+    def __init__(self, model: Model, **args: Any) -> None:
+        super().__init__(model, **args)
+        self.column_types_fn = kind_to_snowflake_type
+
+    def insert_nodes(self, kind: str, nodes: List[Json]) -> Iterator[ValuesBase]:
+        kp, _ = self.kind_properties(self.model.kinds[kind])
+        kind_props = [Property("_id", "string")] + kp
+        select_array = []
+        column_definitions = []
+        prop_is_json = {}
+
+        # Inserting structured data into Snowflake requires a bit of work. General scheme:
+        # insert into TBL(col_string, col_json) select column1, parse_json(column2) from values('a', '{"b":1}');
+        # All json and array elements need to be json encoded and parsed again on the server side.
+        for num, prop in enumerate(kind_props):
+            name = f"column{num + 1}"
+            select_array.append(prop.name)
+            snowflake_kind = kind_to_snowflake_type(prop.kind, self.model)
+            if snowflake_kind in (ARRAY, OBJECT):
+                column_definitions.append(column(f"parse_json({name})", is_literal=True))
+                prop_is_json[prop.name] = True
+            else:
+                column_definitions.append(column(name))
+
+        def values_tuple(node: Json) -> List[Any]:
+            nj = self.node_to_json(node)
+            # make sure to use the same order as in select_array
+            return [json.dumps(nj.get(p.name)) if prop_is_json.get(p.name) else nj.get(p.name) for p in kind_props]
+
+        if (table := self.metadata.tables.get(self.table_name(kind))) is not None:
+            for batch in (nodes[i : i + self.insert_batch_size] for i in range(0, len(nodes), self.insert_batch_size)):
+                converted = [values_tuple(node) for node in batch]
+                yield table.insert().from_select(select_array, select(Values(*column_definitions).data(converted)))
+
+
+# register this updater for the snowflake dialect, when snowflake is installed
+DialectUpdater["snowflake"] = SnowflakeUpdater
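
To see what insert_nodes() emits, the same construct can be compiled without a Snowflake connection. A minimal sketch using plain SQLAlchemy String columns and illustrative table/column names; only the parse_json() wrapper is Snowflake-specific, and the json value is pre-encoded with json.dumps exactly as values_tuple() does:

from sqlalchemy import Column, MetaData, String, Table, column, select
from sqlalchemy.sql import Values

metadata = MetaData()
resources = Table("tmp_aws_instance", metadata, Column("_id", String), Column("tags", String))

# one plain column, one wrapped in parse_json(), mirroring SnowflakeUpdater
cols = [column("column1"), column("parse_json(column2)", is_literal=True)]
rows = [("i-123", '{"env": "prod"}')]

stmt = resources.insert().from_select(["_id", "tags"], select(Values(*cols).data(rows)))
# compiles to roughly:
# INSERT INTO tmp_aws_instance (_id, tags)
# SELECT column1, parse_json(column2) FROM (VALUES (?, ?))
print(stmt)

Snowflake implicitly names the columns of a VALUES clause column1, column2, ..., which is why the real code generates f"column{num + 1}" references rather than named columns.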
