Skip to content

Commit b996a9b

Browse files
feat: [SNOW-1890085] implement dbt deploy command
1 parent a207b94 commit b996a9b

File tree

3 files changed

+188
-77
lines changed

3 files changed

+188
-77
lines changed

src/snowflake/cli/_plugins/dbt/commands.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
from __future__ import annotations
1616

1717
import logging
18-
from typing import Annotated
18+
from pathlib import Path
19+
from typing import Optional
1920

2021
import typer
2122
from snowflake.cli._plugins.dbt.constants import DBT_COMMANDS
2223
from snowflake.cli._plugins.dbt.manager import DBTManager
2324
from snowflake.cli.api.commands.decorators import global_options_with_connection
25+
from snowflake.cli.api.commands.flags import identifier_argument
2426
from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
27+
from snowflake.cli.api.identifiers import FQN
2528
from snowflake.cli.api.output.types import CommandResult, QueryResult
2629

2730
app = SnowTyperFactory(
@@ -32,6 +35,9 @@
3235
log = logging.getLogger(__name__)
3336

3437

38+
DBTNameArgument = identifier_argument(sf_object="DBT Object", example="my_pipeline")
39+
40+
3541
@app.command(
3642
"list",
3743
requires_connection=True,
@@ -45,6 +51,34 @@ def list_dbts(
4551
return QueryResult(DBTManager().list())
4652

4753

54+
@app.command(
55+
"deploy",
56+
requires_connection=True,
57+
)
58+
def deploy_dbt(
59+
name: FQN = DBTNameArgument,
60+
source: Optional[str] = typer.Option(
61+
help="Path to directory containing dbt files to deploy. Defaults to current working directory.",
62+
show_default=False,
63+
default=None,
64+
),
65+
force: Optional[bool] = typer.Option(
66+
False,
67+
help="Overwrites conflicting files in the project, if any.",
68+
),
69+
**options,
70+
) -> CommandResult:
71+
"""
72+
Copy dbt files and create or update dbt on Snowflake project.
73+
"""
74+
# TODO: options for DBT version?
75+
if source is None:
76+
path = Path.cwd()
77+
else:
78+
path = Path(source)
79+
return QueryResult(DBTManager().deploy(path.resolve(), name, force=force))
80+
81+
4882
# `execute` is a pass through command group, meaning that all params after command should be passed over as they are,
4983
# suppressing usual CLI behaviour for displaying help or formatting options.
5084
dbt_execute_app = SnowTyperFactory(
@@ -57,9 +91,7 @@ def list_dbts(
5791
@dbt_execute_app.callback()
5892
@global_options_with_connection
5993
def before_callback(
60-
name: Annotated[
61-
str, typer.Argument(help="Name of the dbt object to execute command on.")
62-
],
94+
name: FQN = DBTNameArgument,
6395
**options,
6496
):
6597
"""Handles global options passed before the command and takes pipeline name to be accessed through child context later"""

src/snowflake/cli/_plugins/dbt/manager.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414

1515
from __future__ import annotations
1616

17+
from pathlib import Path
18+
19+
from click import ClickException
20+
from snowflake.cli._plugins.stage.manager import StageManager
21+
from snowflake.cli.api.console import cli_console
22+
from snowflake.cli.api.identifiers import FQN
1723
from snowflake.cli.api.sql_execution import SqlExecutionMixin
1824
from snowflake.connector.cursor import SnowflakeCursor
1925

@@ -23,6 +29,25 @@ def list(self) -> SnowflakeCursor: # noqa: A003
2329
query = "SHOW DBT"
2430
return self.execute_query(query)
2531

32+
def deploy(self, path: Path, name: FQN, force: bool) -> SnowflakeCursor:
33+
# TODO: what to do with force?
34+
if not path.joinpath("dbt_project.yml").exists():
35+
raise ClickException(f"dbt_project.yml does not exist in provided path.")
36+
37+
with cli_console.phase("Creating temporary stage"):
38+
stage_manager = StageManager()
39+
stage_fqn = FQN.from_string(f"dbt_{name}_stage").using_context()
40+
stage_name = stage_manager.get_standard_stage_prefix(stage_fqn)
41+
stage_manager.create(stage_fqn, temporary=True)
42+
43+
with cli_console.phase("Copying project files to stage"):
44+
results = list(stage_manager.put_recursive(path, stage_name))
45+
cli_console.step(f"Copied {len(results)} files")
46+
47+
with cli_console.phase("Creating DBT project"):
48+
query = f"CREATE OR REPLACE DBT {name} FROM {stage_name}"
49+
return self.execute_query(query)
50+
2651
def execute(self, dbt_command: str, name: str, *dbt_cli_args):
2752
query = f"EXECUTE DBT {name} {dbt_command}"
2853
if dbt_cli_args:

tests/dbt/test_dbt_commands.py

Lines changed: 127 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
from unittest import mock
1618

1719
import pytest
20+
from snowflake.cli.api.identifiers import FQN
1821

1922

2023
@pytest.fixture
@@ -26,76 +29,127 @@ def mock_connect(mock_ctx):
2629
yield _fixture
2730

2831

29-
def test_dbt_list(mock_connect, runner):
30-
31-
result = runner.invoke(["dbt", "list"])
32-
33-
assert result.exit_code == 0, result.output
34-
assert mock_connect.mocked_ctx.get_query() == "SHOW DBT"
35-
36-
37-
@pytest.mark.parametrize(
38-
"args,expected_query",
39-
[
40-
pytest.param(
41-
[
42-
"dbt",
43-
"execute",
44-
"pipeline_name",
45-
"test",
46-
],
47-
"EXECUTE DBT pipeline_name test",
48-
id="simple-command",
49-
),
50-
pytest.param(
51-
[
52-
"dbt",
53-
"execute",
54-
"pipeline_name",
55-
"run",
56-
"-f",
57-
"--select @source:snowplow,tag:nightly models/export",
58-
],
59-
"EXECUTE DBT pipeline_name run -f --select @source:snowplow,tag:nightly models/export",
60-
id="with-dbt-options",
61-
),
62-
pytest.param(
63-
["dbt", "execute", "pipeline_name", "compile", "--vars '{foo:bar}'"],
64-
"EXECUTE DBT pipeline_name compile --vars '{foo:bar}'",
65-
id="with-dbt-vars",
66-
),
67-
pytest.param(
68-
[
69-
"dbt",
70-
"execute",
71-
"pipeline_name",
72-
"compile",
73-
"--format=TXT", # collision with CLI's option; unsupported option
74-
"-v", # collision with CLI's option
75-
"-h",
76-
"--debug",
77-
"--info",
78-
"--config-file=/",
79-
],
80-
"EXECUTE DBT pipeline_name compile --format=TXT -v -h --debug --info --config-file=/",
81-
id="with-dbt-conflicting-options",
82-
),
83-
pytest.param(
84-
[
85-
"dbt",
86-
"execute",
87-
"--format=JSON",
88-
"pipeline_name",
89-
"compile",
90-
],
91-
"EXECUTE DBT pipeline_name compile",
92-
id="with-cli-flag",
93-
),
94-
],
95-
)
96-
def test_dbt_execute(mock_connect, runner, args, expected_query):
97-
98-
result = runner.invoke(args)
99-
100-
assert result.exit_code == 0, result.output
101-
assert mock_connect.mocked_ctx.get_query() == expected_query
32+
class TestDBTList:
33+
def test_dbt_list(self, mock_connect, runner):
34+
35+
result = runner.invoke(["dbt", "list"])
36+
37+
assert result.exit_code == 0, result.output
38+
assert mock_connect.mocked_ctx.get_query() == "SHOW DBT"
39+
40+
41+
class TestDBTDeploy:
42+
@pytest.fixture
43+
def dbt_project_path(self, tmp_path_factory):
44+
source_path = tmp_path_factory.mktemp("dbt_project")
45+
dbt_file = source_path / "dbt_project.yml"
46+
dbt_file.touch()
47+
yield source_path
48+
49+
@pytest.fixture
50+
def mock_cli_console(self):
51+
with mock.patch("snowflake.cli.api.console") as _fixture:
52+
yield _fixture
53+
54+
@mock.patch("snowflake.cli._plugins.dbt.manager.StageManager.put_recursive")
55+
@mock.patch("snowflake.cli._plugins.dbt.manager.StageManager.create")
56+
def test_deploys_project_from_source(
57+
self, mock_create, mock_put_recursive, mock_connect, runner, dbt_project_path
58+
):
59+
60+
result = runner.invoke(
61+
["dbt", "deploy", "TEST_PIPELINE", f"--source={dbt_project_path}"]
62+
)
63+
64+
assert result.exit_code == 0, result.output
65+
assert (
66+
mock_connect.mocked_ctx.get_query()
67+
== "CREATE OR REPLACE DBT TEST_PIPELINE FROM @MockDatabase.MockSchema.dbt_TEST_PIPELINE_stage"
68+
)
69+
stage_fqn = FQN.from_string(f"dbt_TEST_PIPELINE_stage").using_context()
70+
mock_create.assert_called_once_with(stage_fqn, temporary=True)
71+
mock_put_recursive.assert_called_once_with(
72+
dbt_project_path, "@MockDatabase.MockSchema.dbt_TEST_PIPELINE_stage"
73+
)
74+
75+
def test_raises_when_dbt_project_is_not_available(
76+
self, dbt_project_path, mock_connect, runner
77+
):
78+
dbt_file = dbt_project_path / "dbt_project.yml"
79+
dbt_file.unlink()
80+
81+
result = runner.invoke(
82+
["dbt", "deploy", "TEST_PIPELINE", f"--source={dbt_project_path}"]
83+
)
84+
85+
assert result.exit_code == 1, result.output
86+
assert "dbt_project.yml does not exist in provided path." in result.output
87+
assert mock_connect.mocked_ctx.get_query() == ""
88+
89+
90+
class TestDBTExecute:
91+
@pytest.mark.parametrize(
92+
"args,expected_query",
93+
[
94+
pytest.param(
95+
[
96+
"dbt",
97+
"execute",
98+
"pipeline_name",
99+
"test",
100+
],
101+
"EXECUTE DBT pipeline_name test",
102+
id="simple-command",
103+
),
104+
pytest.param(
105+
[
106+
"dbt",
107+
"execute",
108+
"pipeline_name",
109+
"run",
110+
"-f",
111+
"--select @source:snowplow,tag:nightly models/export",
112+
],
113+
"EXECUTE DBT pipeline_name run -f --select @source:snowplow,tag:nightly models/export",
114+
id="with-dbt-options",
115+
),
116+
pytest.param(
117+
["dbt", "execute", "pipeline_name", "compile", "--vars '{foo:bar}'"],
118+
"EXECUTE DBT pipeline_name compile --vars '{foo:bar}'",
119+
id="with-dbt-vars",
120+
),
121+
pytest.param(
122+
[
123+
"dbt",
124+
"execute",
125+
"pipeline_name",
126+
"compile",
127+
"--format=TXT", # collision with CLI's option; unsupported option
128+
"-v", # collision with CLI's option
129+
"-h",
130+
"--debug",
131+
"--info",
132+
"--config-file=/",
133+
],
134+
"EXECUTE DBT pipeline_name compile --format=TXT -v -h --debug --info --config-file=/",
135+
id="with-dbt-conflicting-options",
136+
),
137+
pytest.param(
138+
[
139+
"dbt",
140+
"execute",
141+
"--format=JSON",
142+
"pipeline_name",
143+
"compile",
144+
],
145+
"EXECUTE DBT pipeline_name compile",
146+
id="with-cli-flag",
147+
),
148+
],
149+
)
150+
def test_dbt_execute(self, mock_connect, runner, args, expected_query):
151+
152+
result = runner.invoke(args)
153+
154+
assert result.exit_code == 0, result.output
155+
assert mock_connect.mocked_ctx.get_query() == expected_query

0 commit comments

Comments
 (0)