
Commit 88216ee

feat: [SNOW-2326970] implement dcm test command
1 parent 3f52b9a commit 88216ee

7 files changed: +697 -0 lines changed

src/snowflake/cli/_plugins/dcm/commands.py

Lines changed: 67 additions & 0 deletions
@@ -11,10 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
+from pathlib import Path
 from typing import List, Optional
 
 import typer
 from snowflake.cli._plugins.dcm.manager import DCMProjectManager
+from snowflake.cli._plugins.dcm.utils import (
+    TestResultFormat,
+    export_test_results,
+    format_test_failures,
+)
 from snowflake.cli._plugins.object.command_aliases import add_object_command_aliases
 from snowflake.cli._plugins.object.commands import scope_option
 from snowflake.cli._plugins.object.manager import ObjectManager
@@ -243,6 +250,66 @@ def drop_deployment(
     )
 
 
+@app.command(requires_connection=True)
+def test(
+    identifier: FQN = dcm_identifier,
+    export_format: Optional[List[TestResultFormat]] = typer.Option(
+        None,
+        "--result-format",
+        help="Export test results in specified format(s) into directory set with `--output-path`. Can be specified multiple times for multiple formats.",
+        show_default=False,
+    ),
+    output_path: Optional[Path] = typer.Option(
+        None,
+        "--output-path",
+        help="Directory where test result files will be saved. Defaults to current directory.",
+        show_default=False,
+    ),
+    **options,
+):
+    """
+    Test all expectations set for tables, views and dynamic tables defined
+    in DCM project.
+    """
+    with cli_console.spinner() as spinner:
+        spinner.add_task(description=f"Testing dcm project {identifier}", total=None)
+        result = DCMProjectManager().test(project_identifier=identifier)
+
+    row = result.fetchone()
+    if not row:
+        return MessageResult("No data.")
+
+    result_data = row[0]
+    result_json = (
+        json.loads(result_data) if isinstance(result_data, str) else result_data
+    )
+
+    expectations = result_json.get("expectations", [])
+
+    if not expectations:
+        return MessageResult("No expectations defined in the project.")
+
+    if export_format:
+        if output_path is None:
+            output_path = Path().cwd()
+        saved_files = export_test_results(result_json, export_format, output_path)
+        if saved_files:
+            cli_console.step(f"Test results exported to {output_path.resolve()}.")
+
+    if result_json.get("status") == "EXPECTATION_VIOLATED":
+        failed_expectations = [
+            exp for exp in expectations if exp.get("expectation_violated", False)
+        ]
+        total_tests = len(expectations)
+        failed_count = len(failed_expectations)
+        error_message = format_test_failures(
+            failed_expectations, total_tests, failed_count
+        )
+        raise CliError(error_message)
+
+    return MessageResult(f"All {len(expectations)} expectation(s) passed successfully.")
+
+
 def _get_effective_stage(identifier: FQN, from_location: Optional[str]):
     manager = DCMProjectManager()
     if not from_location:
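
For reference, the command above reads a single result row whose first column is a JSON document. A minimal sketch of the payload shape it expects, inferred only from the keys accessed in the code above (the values are invented for illustration, not actual output):

# Illustrative payload only: keys taken from the command code above,
# values made up for the example.
sample_result = {
    "status": "EXPECTATION_VIOLATED",
    "expectations": [
        {
            "table_name": "MY_DB.PUBLIC.ORDERS",
            "expectation_name": "row_count_positive",
            "metric_name": "row_count",
            "expectation_expression": "row_count > 0",
            "value": 0,
            "expectation_violated": True,
        }
    ],
}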

src/snowflake/cli/_plugins/dcm/manager.py

Lines changed: 4 additions & 0 deletions
@@ -139,6 +139,10 @@ def drop_deployment(
         query += f' "{deployment_name}"'
         return self.execute_query(query=query)
 
+    def test(self, project_identifier: FQN):
+        query = f"EXECUTE DCM PROJECT {project_identifier.sql_identifier} TEST ALL"
+        return self.execute_query(query=query)
+
     @staticmethod
     def sync_local_files(
         project_identifier: FQN, source_directory: str | None = None
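
A tiny illustration (not part of the commit) of the statement this method issues, assuming a project identifier that renders as MY_DB.PUBLIC.MY_PROJECT via FQN.sql_identifier:

# Hypothetical example of the query built by DCMProjectManager.test().
expected_query = "EXECUTE DCM PROJECT MY_DB.PUBLIC.MY_PROJECT TEST ALL"
print(expected_query)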
src/snowflake/cli/_plugins/dcm/utils.py

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
+# Copyright (c) 2024 Snowflake Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List
+from xml.etree import ElementTree
+
+
+class TestResultFormat(str, Enum):
+    JSON = "json"
+    JUNIT = "junit"
+    TAP = "tap"
+
+
+def format_test_failures(
+    failed_expectations: list, total_tests: int, failed_count: int
+) -> str:
+    """Format test failures into a nice error message."""
+    lines = [
+        "Failed expectations:",
+    ]
+
+    for failed in failed_expectations:
+        table_name = failed.get("table_name", "Unknown")
+        expectation_name = failed.get("expectation_name", "Unknown")
+        metric_name = failed.get("metric_name", "Unknown")
+        expectation_expr = failed.get("expectation_expression", "N/A")
+        value = failed.get("value", "N/A")
+
+        lines.append(f" Table: {table_name}")
+        lines.append(f" Expectation: {expectation_name}")
+        lines.append(f" Metric: {metric_name}")
+        lines.append(f" Expression: {expectation_expr}")
+        lines.append(f" Actual value: {value}")
+        lines.append("")
+
+    passed_tests = total_tests - failed_count
+    lines.append(
+        f"Tests completed: {passed_tests} passed, {failed_count} failed out of {total_tests} total."
+    )
+
+    return "\n".join(lines)
+
+
+def _normalize_table_name(table_name: str) -> str:
+    """Normalize table name to lowercase with hyphens for file naming."""
+    return table_name.lower().replace(".", "-").replace("_", "-")
+
+
+def _group_expectations_by_table(
+    expectations: List[Dict[str, Any]]
+) -> Dict[str, List[Dict[str, Any]]]:
+    """Group expectations by table name."""
+    grouped: dict[str, list[dict[str, Any]]] = {}
+    for expectation in expectations:
+        table_name = expectation.get("table_name", "unknown")
+        if table_name not in grouped:
+            grouped[table_name] = []
+        grouped[table_name].append(expectation)
+    return grouped
+
+
+def export_test_results_as_json(result_data: Dict[str, Any], output_path: Path) -> None:
+    """Export test results as JSON format."""
+    with open(output_path, "w") as f:
+        json.dump(result_data, f, indent=2)
+
+
+def export_test_results_as_junit(
+    result_data: Dict[str, Any], output_dir: Path
+) -> List[Path]:
+    """Export test results as JUnit XML format, one file per table."""
+    expectations = result_data.get("expectations", [])
+    grouped = _group_expectations_by_table(expectations)
+
+    junit_dir = output_dir / "junit"
+    junit_dir.mkdir(parents=True, exist_ok=True)
+
+    saved_files = []
+
+    for table_name, table_expectations in grouped.items():
+        normalized_name = _normalize_table_name(table_name)
+        output_path = junit_dir / f"{normalized_name}.xml"
+
+        testsuites = ElementTree.Element("testsuites")
+        testsuite = ElementTree.SubElement(
+            testsuites,
+            "testsuite",
+            name=f"DCM Tests - {table_name}",
+            tests=str(len(table_expectations)),
+            failures=str(
+                sum(
+                    1
+                    for e in table_expectations
+                    if e.get("expectation_violated", False)
+                )
+            ),
+            errors="0",
+            skipped="0",
+        )
+
+        for expectation in table_expectations:
+            expectation_name = expectation.get("expectation_name", "Unknown")
+            metric_name = expectation.get("metric_name", "Unknown")
+
+            testcase = ElementTree.SubElement(
+                testsuite,
+                "testcase",
+                name=expectation_name,
+                classname=table_name,
+            )
+
+            if expectation.get("expectation_violated", False):
+                failure = ElementTree.SubElement(
+                    testcase,
+                    "failure",
+                    message=f"Expectation '{expectation_name}' violated",
+                    type="AssertionError",
+                )
+                expectation_expr = expectation.get("expectation_expression", "N/A")
+                value = expectation.get("value", "N/A")
+                failure.text = (
+                    f"Metric: {metric_name}\n"
+                    f"Expression: {expectation_expr}\n"
+                    f"Actual value: {value}"
+                )
+
+        tree = ElementTree.ElementTree(testsuites)
+        ElementTree.indent(tree, space=" ")
+        tree.write(output_path, encoding="utf-8", xml_declaration=True)
+        saved_files.append(output_path)
+
+    return saved_files
+
+
+def export_test_results_as_tap(
+    result_data: Dict[str, Any], output_dir: Path
+) -> List[Path]:
+    """Export test results as TAP (Test Anything Protocol) format, one file per table."""
+    expectations = result_data.get("expectations", [])
+    grouped = _group_expectations_by_table(expectations)
+
+    tap_dir = output_dir / "tap"
+    tap_dir.mkdir(parents=True, exist_ok=True)
+
+    saved_files = []
+
+    for table_name, table_expectations in grouped.items():
+        normalized_name = _normalize_table_name(table_name)
+        output_path = tap_dir / f"{normalized_name}.tap"
+
+        lines = [f"1..{len(table_expectations)}"]
+
+        for idx, expectation in enumerate(table_expectations, start=1):
+            expectation_name = expectation.get("expectation_name", "Unknown")
+            metric_name = expectation.get("metric_name", "Unknown")
+
+            if expectation.get("expectation_violated", False):
+                lines.append(f"not ok {idx} - {expectation_name}")
+                lines.append(f" ---")
+                lines.append(f" message: Expectation '{expectation_name}' violated")
+                lines.append(f" severity: fail")
+                lines.append(f" data:")
+                lines.append(f" table: {table_name}")
+                lines.append(f" metric: {metric_name}")
+                lines.append(
+                    f" expression: {expectation.get('expectation_expression', 'N/A')}"
+                )
+                lines.append(f" actual_value: {expectation.get('value', 'N/A')}")
+                lines.append(f" ...")
+            else:
+                lines.append(f"ok {idx} - {expectation_name}")
+
+        with open(output_path, "w") as f:
+            f.write("\n".join(lines) + "\n")
+
+        saved_files.append(output_path)
+
+    return saved_files
+
+
+def export_test_results(
+    result_data: Dict[str, Any],
+    formats: List[TestResultFormat],
+    output_dir: Path,
+) -> List[Path]:
+    """
+    Export test results in multiple formats.
+
+    Args:
+        result_data: The test result data from the backend
+        formats: List of formats to export to
+        output_dir: Directory to save the results
+
+    Returns:
+        List of paths where results were saved
+    """
+    output_dir.mkdir(parents=True, exist_ok=True)
+    saved_files = []
+
+    for format_type in formats:
+        if format_type == TestResultFormat.JSON:
+            output_path = output_dir / "test_result.json"
+            export_test_results_as_json(result_data, output_path)
+            saved_files.append(output_path)
+        elif format_type == TestResultFormat.JUNIT:
+            files = export_test_results_as_junit(result_data, output_dir)
+            saved_files.extend(files)
+        elif format_type == TestResultFormat.TAP:
+            files = export_test_results_as_tap(result_data, output_dir)
+            saved_files.extend(files)
+
+    return saved_files
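
As a quick sanity check (not part of the commit), here is how the exporters above could be driven end to end; the import path matches the one added to commands.py earlier in this diff, and the payload is a made-up example:

# Hypothetical usage of the new helpers; payload values are invented for illustration.
from pathlib import Path

from snowflake.cli._plugins.dcm.utils import TestResultFormat, export_test_results

payload = {
    "status": "SUCCESS",
    "expectations": [
        {
            "table_name": "MY_DB.PUBLIC.ORDERS",
            "expectation_name": "row_count_positive",
            "metric_name": "row_count",
            "expectation_expression": "row_count > 0",
            "value": 42,
            "expectation_violated": False,
        }
    ],
}

saved = export_test_results(
    result_data=payload,
    formats=[TestResultFormat.JSON, TestResultFormat.JUNIT, TestResultFormat.TAP],
    output_dir=Path("test_results"),
)
# Based on the functions above, this should write:
#   test_results/test_result.json
#   test_results/junit/my-db-public-orders.xml
#   test_results/tap/my-db-public-orders.tap
print([str(p) for p in saved])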
