Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/snowflake/cli/_plugins/cicd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2025 Snowflake Inc.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd put that module inside init plugin (../_plugins/init/cicd) - currently it is not used by any other context and is not providing new commands

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
74 changes: 74 additions & 0 deletions src/snowflake/cli/_plugins/cicd/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import enum
from typing import List, Optional, Type

from snowflake.cli.api.secure_path import SecurePath


class CIProviderChoices(str, enum.Enum):
GITHUB = "GITHUB"
GITLAB = "GITLAB"


class CIProvider:
NAME: str

@classmethod
def cleanup(cls, root: SecurePath) -> None:
raise NotImplementedError()

@classmethod
def from_choice(cls, choice: CIProviderChoices) -> "CIProvider":
return {
GithubProvider.NAME: GithubProvider,
GitLabProvider.NAME: GitLabProvider,
}[choice.name]()

@classmethod
def all(cls) -> List[Type["CIProvider"]]: # noqa: A003
return [GithubProvider, GitLabProvider]

def has_template(self, root_dir: SecurePath) -> bool:
raise NotImplementedError()

def copy(self, source: SecurePath, destination: SecurePath):
raise NotImplementedError()


class GithubProvider(CIProvider):
NAME = CIProviderChoices.GITHUB.name

@classmethod
def cleanup(cls, root_dir: SecurePath):
(root_dir / ".github").rmdir(recursive=True)

def has_template(self, root_dir: SecurePath) -> bool:
return (root_dir / ".github/workflows").exists()

def copy(self, source: SecurePath, destination: SecurePath) -> None:
(source / ".github").copy(destination.path, dirs_exist_ok=True)


class GitLabProvider(CIProvider):
NAME = CIProviderChoices.GITLAB.name

@classmethod
def cleanup(cls, root_dir: SecurePath):
(root_dir / ".gitlab-ci.yml").unlink(missing_ok=True)

def has_template(self, root_dir: SecurePath) -> bool:
return (root_dir / ".gitlab-ci.yml").exists()

def copy(self, source: SecurePath, destination: SecurePath) -> None:
if (destination / ".gitlab-ci.yml").exists():
(destination / ".gitlab-ci.yml").unlink()
(source / ".gitlab-ci.yml").move(destination.path)


class CIProviderManager:
@staticmethod
def project_post_gen_cleanup(
selected_provider: Optional[CIProvider], template_root: SecurePath
):
for provider_cls in CIProvider.all():
if selected_provider and not isinstance(selected_provider, provider_cls):
provider_cls.cleanup(template_root)
24 changes: 24 additions & 0 deletions src/snowflake/cli/_plugins/cicd/plugin_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2024 Snowflake Inc.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftover file

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# from snowflake.cli._plugins.cicd import commands

#
# @plugin_hook_impl
# def command_spec():
# return CommandSpec(
# parent_command_path=SNOWCLI_ROOT_COMMAND_PATH,
# command_type=CommandType.COMMAND_GROUP,
# typer_instance=commands.app.create_instance(),
# )
95 changes: 82 additions & 13 deletions src/snowflake/cli/_plugins/init/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
import yaml
from click import ClickException
from snowflake.cli.__about__ import VERSION
from snowflake.cli._plugins.cicd.manager import (
CIProvider,
CIProviderChoices,
CIProviderManager,
)
from snowflake.cli.api.commands.flags import (
NoInteractiveOption,
variables_option,
Expand Down Expand Up @@ -72,6 +77,17 @@ def _path_argument_callback(path: str) -> str:
"--template-source",
help=f"local path to template directory or URL to git repository with templates.",
)
CIProviderOption = typer.Option(
None,
"--ci-provider",
help=f"CI provider to generate workflow for.",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
help=f"CI provider to generate workflow for.",
help=f"generate CI/CD workflow for given provider.",

case_sensitive=True,
)
CITemplateSourceOption = typer.Option(
None,
"--ci-template-source",
help=f"local path to template directory or URL to git repository with ci/cd templates.",
)
VariablesOption = variables_option(
"String in `key=value` format. Provided variables will not be prompted for."
)
Expand Down Expand Up @@ -191,6 +207,8 @@ def init(
path: str = PathArgument,
template: Optional[str] = TemplateOption,
template_source: Optional[str] = SourceOption,
ci_provider: Optional[CIProviderChoices] = CIProviderOption,
ci_template_source: Optional[str] = CITemplateSourceOption,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we validate that "ci_provider" must be defined if "ci_template_source" is procided?

variables: Optional[List[str]] = VariablesOption,
no_interactive: bool = NoInteractiveOption,
**options,
Expand All @@ -201,30 +219,30 @@ def init(
variables_from_flags = {
v.key: v.value for v in parse_key_value_variables(variables)
}
is_remote = any(
template_source.startswith(prefix) for prefix in ["git@", "http://", "https://"] # type: ignore
)
args_error_msg = f"Check whether {TemplateOption.param_decls[0]} and {SourceOption.param_decls[0]} arguments are correct."

# copy/download template into tmpdir, so it is going to be removed in case command ends with an error
with SecurePath.temporary_directory() as tmpdir:
if is_remote:
template_root = _fetch_remote_template(
url=template_source, path=template, destination=tmpdir # type: ignore
)
else:
template_root = _fetch_local_template(
template_source=SecurePath(template_source),
path=template,
destination=tmpdir,
)
assert isinstance(template_source, str)
template_root = _fetch_template(template_source, template, tmpdir)

template_metadata = _read_template_metadata(
template_root, args_error_msg=args_error_msg
)
if template_metadata.minimum_cli_version:
_validate_cli_version(template_metadata.minimum_cli_version)

if ci_provider:
ci_provider_instance = CIProvider.from_choice(ci_provider)
clone(
ci_provider_instance,
ci_template_source,
template_metadata,
template_root,
)
else:
ci_provider_instance = None

variable_values = _determine_variable_values(
variables_metadata=template_metadata.variables,
variables_from_flags=variables_from_flags,
Expand All @@ -242,7 +260,58 @@ def init(
data=variable_values,
)
_remove_template_metadata_file(template_root)
post_generate(template_root, ci_provider_instance)
SecurePath(path).parent.mkdir(exist_ok=True, parents=True)
template_root.copy(path)

return MessageResult(f"Initialized the new project in {path}")


def clone(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"_fetch_cicd_template" would be a better name - "clone" assumes the template lives in remote repo

ci_provider_instance: CIProvider,
ci_template_source: Optional[str],
template_metadata: Template,
template_root: SecurePath,
):
if ci_template_source is not None:
with SecurePath.temporary_directory() as cicd_tmpdir:
cicd_template_root = _fetch_template(ci_template_source, None, cicd_tmpdir)
ci_provider_instance.copy(cicd_template_root, template_root)
ci_template_metadata = _read_template_metadata(
cicd_template_root,
args_error_msg="template.yml is required for --ci-template-source.",
)
template_metadata.merge(ci_template_metadata)

elif ci_provider_instance.has_template(template_root):
pass # template has ci files
else:
raise ClickException(
f"Template for {ci_provider_instance.NAME} not provided and not configured on selected template."
)


def _fetch_template(
template_source: str, template: Optional[str], tmpdir: SecurePath
) -> SecurePath:
if _is_remote_source(template_source):
template_root = _fetch_remote_template(
url=template_source, path=template, destination=tmpdir # type: ignore
)
else:
template_root = _fetch_local_template(
template_source=SecurePath(template_source),
path=template,
destination=tmpdir,
)
return template_root
Comment on lines +297 to +307
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if _is_remote_source(template_source):
template_root = _fetch_remote_template(
url=template_source, path=template, destination=tmpdir # type: ignore
)
else:
template_root = _fetch_local_template(
template_source=SecurePath(template_source),
path=template,
destination=tmpdir,
)
return template_root
if _is_remote_source(template_source):
return _fetch_remote_template(
url=template_source, path=template, destination=tmpdir # type: ignore
)
return _fetch_local_template(
template_source=SecurePath(template_source),
path=template,
destination=tmpdir,
)



def _is_remote_source(template_source: str) -> bool:
return any(
template_source.startswith(prefix) for prefix in ["git@", "http://", "https://"] # type: ignore
)


def post_generate(template_root: SecurePath, ci_provider: Optional[CIProvider]):
CIProviderManager.project_post_gen_cleanup(ci_provider, template_root)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? What we assume about Ci/CD templates?

30 changes: 29 additions & 1 deletion src/snowflake/cli/api/project/schemas/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

import typer
from click import ClickException
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from snowflake.cli.api.exceptions import InvalidTemplate
from snowflake.cli.api.secure_path import SecurePath


class TemplateVariable(BaseModel):
model_config = ConfigDict(frozen=True)

name: str = Field(..., title="Variable identifier")
type: Optional[Literal["string", "float", "int"]] = Field( # noqa: A003
title="Type of the variable", default=None
Expand Down Expand Up @@ -64,6 +66,32 @@ def __init__(self, template_root: SecurePath, **kwargs):
super().__init__(**kwargs)
self._validate_files_exist(template_root)

def merge(self, other: Template):
if not isinstance(other, Template):
raise ClickException(f"Can not merge template with {type(other)}")

errors = []
if self.minimum_cli_version != other.minimum_cli_version:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just using max here?

errors.append(
f"minimum_cli_versions do not match: {self.minimum_cli_version} != {other.minimum_cli_version}"
)
variable_map = {variable.name: variable for variable in self.variables}
for other_variable in other.variables:
if self_variable := variable_map.get(other_variable.name):
for attr in ["type", "prompt", "default"]:
if getattr(self_variable, attr) != getattr(other_variable, attr):
errors.append(
f"Conflicting variable definitions: '{self_variable.name}' has different values for attribute '{attr}': '{getattr(self_variable, attr)}' != '{getattr(other_variable, attr)}'"
)
if errors:
error_str = "\n\t" + "\n\t".join(error for error in errors)
raise ClickException(
f"Could not merge templates. Following errors found:{error_str}"
)
self.files_to_render = list(set(self.files_to_render + other.files_to_render))
self.variables = list(set(self.variables + other.variables))
return self

def _validate_files_exist(self, template_root: SecurePath) -> None:
for path_in_template in self.files_to_render:
full_path = template_root / path_in_template
Expand Down
Loading