Skip to content

Commit be94439

Browse files
feat: [SNOW-1890085] dbt deploy: add support for dbt-version flag
1 parent 80f2fcf commit be94439

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

src/snowflake/cli/_plugins/dbt/commands.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ def deploy_dbt(
6666
False,
6767
help="Overwrites conflicting files in the project, if any.",
6868
),
69+
dbt_version: Optional[str] = typer.Option(
70+
None,
71+
help="Version of dbt tool to be used. Taken from dbt_project.yml if not provided.",
72+
),
6973
**options,
7074
) -> CommandResult:
7175
"""
@@ -76,7 +80,9 @@ def deploy_dbt(
7680
path = Path.cwd()
7781
else:
7882
path = Path(source)
79-
return QueryResult(DBTManager().deploy(path.resolve(), name, force=force))
83+
return QueryResult(
84+
DBTManager().deploy(path.resolve(), name, dbt_version, force=force)
85+
)
8086

8187

8288
# `execute` is a pass through command group, meaning that all params after command should be passed over as they are,

src/snowflake/cli/_plugins/dbt/manager.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from __future__ import annotations
1616

1717
from pathlib import Path
18+
from typing import Optional
1819

20+
import yaml
1921
from click import ClickException
2022
from snowflake.cli._plugins.stage.manager import StageManager
2123
from snowflake.cli.api.console import cli_console
@@ -42,11 +44,23 @@ def list(self) -> SnowflakeCursor: # noqa: A003
4244
query = "SHOW DBT PROJECT"
4345
return self.execute_query(query)
4446

45-
def deploy(self, path: Path, name: FQN, force: bool) -> SnowflakeCursor:
47+
def deploy(
48+
self, path: Path, name: FQN, dbt_version: Optional[str], force: bool
49+
) -> SnowflakeCursor:
4650
# TODO: what to do with force?
4751
if not path.joinpath("dbt_project.yml").exists():
4852
raise ClickException(f"dbt_project.yml does not exist in provided path.")
4953

54+
if dbt_version is None:
55+
with path.joinpath("dbt_project.yml").open() as fd:
56+
dbt_project_config = yaml.safe_load(fd)
57+
try:
58+
dbt_version = dbt_project_config["version"]
59+
except (KeyError, TypeError):
60+
raise ClickException(
61+
f"dbt-version was not provided and is not available in dbt_project.yml"
62+
)
63+
5064
with cli_console.phase("Creating temporary stage"):
5165
stage_manager = StageManager()
5266
stage_fqn = FQN.from_string(f"dbt_{name}_stage").using_context()
@@ -58,7 +72,7 @@ def deploy(self, path: Path, name: FQN, force: bool) -> SnowflakeCursor:
5872
cli_console.step(f"Copied {len(results)} files")
5973

6074
with cli_console.phase("Creating DBT project"):
61-
query = f"CREATE OR REPLACE DBT PROJECT {name} FROM {stage_name}"
75+
query = f"CREATE OR REPLACE DBT PROJECT {name} FROM {stage_name} DBT_VERSION='{dbt_version}'"
6276
return self.execute_query(query)
6377

6478
def execute(self, dbt_command: str, name: str, *dbt_cli_args):

tests/dbt/test_dbt_commands.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from unittest import mock
1818

1919
import pytest
20+
import yaml
2021
from snowflake.cli.api.identifiers import FQN
2122

2223

@@ -44,6 +45,8 @@ def dbt_project_path(self, tmp_path_factory):
4445
source_path = tmp_path_factory.mktemp("dbt_project")
4546
dbt_file = source_path / "dbt_project.yml"
4647
dbt_file.touch()
48+
with dbt_file.open(mode="w") as fd:
49+
yaml.dump({"version": "1.2.3"}, fd)
4750
yield source_path
4851

4952
@pytest.fixture
@@ -64,14 +67,35 @@ def test_deploys_project_from_source(
6467
assert result.exit_code == 0, result.output
6568
assert (
6669
mock_connect.mocked_ctx.get_query()
67-
== "CREATE OR REPLACE DBT PROJECT TEST_PIPELINE FROM @MockDatabase.MockSchema.dbt_TEST_PIPELINE_stage"
70+
== "CREATE OR REPLACE DBT PROJECT TEST_PIPELINE FROM @MockDatabase.MockSchema.dbt_TEST_PIPELINE_stage DBT_VERSION='1.2.3'"
6871
)
6972
stage_fqn = FQN.from_string(f"dbt_TEST_PIPELINE_stage").using_context()
7073
mock_create.assert_called_once_with(stage_fqn, temporary=True)
7174
mock_put_recursive.assert_called_once_with(
7275
dbt_project_path, "@MockDatabase.MockSchema.dbt_TEST_PIPELINE_stage"
7376
)
7477

78+
@mock.patch("snowflake.cli._plugins.dbt.manager.StageManager.put_recursive")
79+
@mock.patch("snowflake.cli._plugins.dbt.manager.StageManager.create")
80+
def test_dbt_version_from_option_has_precedence_over_file(
81+
self, _mock_create, _mock_put_recursive, mock_connect, runner, dbt_project_path
82+
):
83+
result = runner.invoke(
84+
[
85+
"dbt",
86+
"deploy",
87+
"TEST_PIPELINE",
88+
f"--source={dbt_project_path}",
89+
"--dbt-version=2.3.4",
90+
]
91+
)
92+
93+
assert result.exit_code == 0, result.output
94+
assert (
95+
mock_connect.mocked_ctx.get_query()
96+
== "CREATE OR REPLACE DBT PROJECT TEST_PIPELINE FROM @MockDatabase.MockSchema.dbt_TEST_PIPELINE_stage DBT_VERSION='2.3.4'"
97+
)
98+
7599
def test_raises_when_dbt_project_is_not_available(
76100
self, dbt_project_path, mock_connect, runner
77101
):
@@ -86,6 +110,24 @@ def test_raises_when_dbt_project_is_not_available(
86110
assert "dbt_project.yml does not exist in provided path." in result.output
87111
assert mock_connect.mocked_ctx.get_query() == ""
88112

113+
def test_raises_when_dbt_project_version_is_not_specified(
114+
self, dbt_project_path, mock_connect, runner
115+
):
116+
dbt_file = dbt_project_path / "dbt_project.yml"
117+
with dbt_file.open(mode="w") as fd:
118+
yaml.dump({}, fd)
119+
120+
result = runner.invoke(
121+
["dbt", "deploy", "TEST_PIPELINE", f"--source={dbt_project_path}"]
122+
)
123+
124+
assert result.exit_code == 1, result.output
125+
assert (
126+
"dbt-version was not provided and is not available in dbt_project.yml"
127+
in result.output
128+
)
129+
assert mock_connect.mocked_ctx.get_query() == ""
130+
89131

90132
class TestDBTExecute:
91133
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)