diff --git a/pyproject.toml b/pyproject.toml index b3a5101d287..b322a4c9681 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "zenml" version = "0.85.0" -packages = [{ include = "zenml", from = "src" }] +packages = [{ include = "zenml", from = "src" }, { include = "zenml_cli", from = "src" }] description = "ZenML: Write production-ready ML code." authors = ["ZenML GmbH "] readme = "README.md" @@ -38,7 +38,7 @@ exclude = [ include = ["src/zenml", "*.txt", "*.sh", "*.md"] [tool.poetry.scripts] -zenml = "zenml.cli.cli:cli" +zenml = "zenml_cli:cli" [tool.poetry.dependencies] alembic = { version = ">=1.8.1,<=1.15.2" } diff --git a/src/zenml/cli/artifact.py b/src/zenml/cli/artifact.py index 751d85f7b49..db66df31e45 100644 --- a/src/zenml/cli/artifact.py +++ b/src/zenml/cli/artifact.py @@ -13,18 +13,23 @@ # permissions and limitations under the License. """CLI functionality to interact with artifacts.""" -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional import click from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli +from zenml.cli.utils import ( + list_options, +) from zenml.client import Client +from zenml.console import console from zenml.enums import CliCategories from zenml.logger import get_logger -from zenml.models import ArtifactFilter, ArtifactVersionFilter -from zenml.models.v2.core.artifact import ArtifactResponse -from zenml.models.v2.core.artifact_version import ArtifactVersionResponse +from zenml.models import ( + ArtifactFilter, + ArtifactVersionFilter, +) from zenml.utils.pagination_utils import depaginate logger = get_logger(__name__) @@ -35,25 +40,39 @@ def artifact() -> None: """Commands for interacting with artifacts.""" -@cli_utils.list_options(ArtifactFilter) +@list_options( + ArtifactFilter, + default_columns=["id", "name", "latest_version_name", "user", "created"], +) @artifact.command("list", help="List all artifacts.") -def list_artifacts(**kwargs: Any) -> None: +def list_artifacts(output_format: str, columns: str, **kwargs: Any) -> None: """List all artifacts. Args: - **kwargs: Keyword arguments to filter artifacts by. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter artifacts by. """ - artifacts = Client().list_artifacts(**kwargs) - - if not artifacts: - cli_utils.declare("No artifacts found.") - return - - to_print = [] - for artifact in artifacts: - to_print.append(_artifact_to_print(artifact)) + with console.status("Listing artifacts..."): + artifacts = Client().list_artifacts(**kwargs) + + artifact_list = [] + for artifact in artifacts.items: + artifact_data = cli_utils.prepare_response_data(artifact) + artifact_data.update( + { + "latest_version_name": artifact.latest_version_name, + "latest_version_id": artifact.latest_version_id, + } + ) + artifact_list.append(artifact_data) - cli_utils.print_table(to_print) + cli_utils.handle_output( + artifact_list, + pagination_info=artifacts.pagination_info, + columns=columns, + output_format=output_format, + ) @artifact.command("update", help="Update an artifact.") @@ -115,25 +134,42 @@ def version() -> None: """Commands for interacting with artifact versions.""" -@cli_utils.list_options(ArtifactVersionFilter) +@list_options( + ArtifactVersionFilter, + default_columns=["id", "name", "version", "type", "user", "created"], +) @version.command("list", help="List all artifact versions.") -def list_artifact_versions(**kwargs: Any) -> None: +def list_artifact_versions( + output_format: str, columns: str, **kwargs: Any +) -> None: """List all artifact versions. Args: - **kwargs: Keyword arguments to filter artifact versions by. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter artifact versions by. """ - artifact_versions = Client().list_artifact_versions(**kwargs) - - if not artifact_versions: - cli_utils.declare("No artifact versions found.") - return + with console.status("Listing artifact versions..."): + artifact_versions = Client().list_artifact_versions(**kwargs) - to_print = [] - for artifact_version in artifact_versions: - to_print.append(_artifact_version_to_print(artifact_version)) + artifact_version_list = [] + for artifact_version in artifact_versions.items: + artifact_version_data = cli_utils.prepare_response_data( + artifact_version + ) + artifact_version_data.update( + { + "name": artifact_version.artifact.name, + } + ) + artifact_version_list.append(artifact_version_data) - cli_utils.print_table(to_print) + cli_utils.handle_output( + artifact_version_list, + pagination_info=artifact_versions.pagination_info, + columns=columns, + output_format=output_format, + ) @version.command("update", help="Update an artifact version.") @@ -294,28 +330,3 @@ def prune_artifacts( f"Failed to delete artifact version {unused_artifact_version.id}: {str(e)}" ) cli_utils.declare("All unused artifacts and artifact versions deleted.") - - -def _artifact_version_to_print( - artifact_version: ArtifactVersionResponse, -) -> Dict[str, Any]: - return { - "id": artifact_version.id, - "name": artifact_version.artifact.name, - "version": artifact_version.version, - "uri": artifact_version.uri, - "type": artifact_version.type, - "materializer": artifact_version.materializer, - "data_type": artifact_version.data_type, - "tags": [t.name for t in artifact_version.tags], - } - - -def _artifact_to_print( - artifact_version: ArtifactResponse, -) -> Dict[str, Any]: - return { - "id": artifact_version.id, - "name": artifact_version.name, - "tags": [t.name for t in artifact_version.tags], - } diff --git a/src/zenml/cli/authorized_device.py b/src/zenml/cli/authorized_device.py index 80701b263d9..fc550edb6d5 100644 --- a/src/zenml/cli/authorized_device.py +++ b/src/zenml/cli/authorized_device.py @@ -56,27 +56,37 @@ def describe_authorized_device(id_or_prefix: str) -> None: ) +@list_options( + OAuthDeviceFilter, + default_columns=["status", "ip_address", "hostname", "os", "created"], +) @authorized_device.command( "list", help="List all authorized devices for the current user." ) -@list_options(OAuthDeviceFilter) -def list_authorized_devices(**kwargs: Any) -> None: +def list_authorized_devices( + output_format: str, columns: str, **kwargs: Any +) -> None: """List all authorized devices. Args: - **kwargs: Keyword arguments to filter authorized devices. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter authorized devices. """ - with console.status("Listing authorized devices...\n"): + with console.status("Listing authorized devices..."): devices = Client().list_authorized_devices(**kwargs) - if not devices.items: - cli_utils.declare("No authorized devices found for this filter.") - return + device_list = [] + for device in devices.items: + device_data = cli_utils.prepare_response_data(device) + device_list.append(device_data) - cli_utils.print_pydantic_models( - devices, - columns=["id", "status", "ip_address", "hostname", "os"], - ) + cli_utils.handle_output( + device_list, + pagination_info=devices.pagination_info, + columns=columns, + output_format=output_format, + ) @authorized_device.command("lock") diff --git a/src/zenml/cli/code_repository.py b/src/zenml/cli/code_repository.py index 0692476f23b..c44a623301d 100644 --- a/src/zenml/cli/code_repository.py +++ b/src/zenml/cli/code_repository.py @@ -19,7 +19,9 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import list_options +from zenml.cli.utils import ( + list_options, +) from zenml.client import Client from zenml.code_repositories import BaseCodeRepository from zenml.config.source import Source @@ -188,25 +190,34 @@ def describe_code_repository(name_id_or_prefix: str) -> None: ) +@list_options( + CodeRepositoryFilter, default_columns=["name", "type", "url", "created"] +) @code_repository.command("list", help="List all connected code repositories.") -@list_options(CodeRepositoryFilter) -def list_code_repositories(**kwargs: Any) -> None: +def list_code_repositories( + output_format: str, columns: str, **kwargs: Any +) -> None: """List all connected code repositories. Args: - **kwargs: Keyword arguments to filter code repositories. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter code repositories. """ - with console.status("Listing code repositories...\n"): + with console.status("Listing code repositories..."): repos = Client().list_code_repositories(**kwargs) - if not repos.items: - cli_utils.declare("No code repositories found for this filter.") - return + repo_list = [] + for repo in repos.items: + repo_data = cli_utils.prepare_response_data(repo) + repo_list.append(repo_data) - cli_utils.print_pydantic_models( - repos, - exclude_columns=["created", "updated", "user", "project"], - ) + cli_utils.handle_output( + repo_list, + pagination_info=repos.pagination_info, + columns=columns, + output_format=output_format, + ) @code_repository.command( diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index 3bf12beb583..054bec08230 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -19,7 +19,11 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli +from zenml.cli.utils import ( + list_options, +) from zenml.client import Client +from zenml.console import console from zenml.enums import CliCategories, ModelStages from zenml.exceptions import EntityExistsError from zenml.logger import get_logger @@ -27,8 +31,10 @@ ModelFilter, ModelResponse, ModelVersionArtifactFilter, + ModelVersionArtifactResponse, ModelVersionFilter, ModelVersionPipelineRunFilter, + ModelVersionPipelineRunResponse, ModelVersionResponse, ) from zenml.utils.dict_utils import remove_none_values @@ -36,39 +42,108 @@ logger = get_logger(__name__) -def _model_to_print(model: ModelResponse) -> Dict[str, Any]: +def _generate_model_data(model: ModelResponse) -> Dict[str, Any]: + """Generate additional data for model display. + + Args: + model: The model response. + + Returns: + The additional data for the model. + """ return { - "id": model.id, - "name": model.name, - "latest_version": model.latest_version_name, - "description": model.description, - "tags": [t.name for t in model.tags], - "save_to_registry": ":white_check_mark:" - if model.save_models_to_registry + "latest_version_name": model.latest_version_name, + "latest_version_id": model.latest_version_id, + } + + +def _generate_model_version_data( + model_version: ModelVersionResponse, output_format: str +) -> Dict[str, Any]: + """Generate additional data for model version display. + + Args: + model_version: The model version response. + output_format: The output format. + + Returns: + The additional data for the model version. + """ + # Get stage value for formatting + stage_value = ( + str(model_version.stage).lower() if model_version.stage else "" + ) + model_name = model_version.model.name + version_name = model_version.name + stage_display = str(model_version.stage) if model_version.stage else "" + + # Apply stage-based formatting only for table output + if output_format == "table": + if stage_value == "production": + # Format with green dot at beginning, green model name and version + formatted_model = ( + f"[green]●[/green] [bold green]{model_name}[/bold green]" + ) + formatted_version = f"[bold green]{version_name}[/bold green]" + formatted_stage = f"[bold green]{stage_display}[/bold green]" + elif stage_value == "staging": + # Format with orange dot at beginning, orange name and version + formatted_model = f"[bright_yellow]●[/bright_yellow] [bright_yellow]{model_name}[/bright_yellow]" + formatted_version = ( + f"[bright_yellow]{version_name}[/bright_yellow]" + ) + formatted_stage = f"[bright_yellow]{stage_display}[/bright_yellow]" + else: + # For other stages (development, archived, etc.), keep default format + formatted_model = model_name + formatted_version = version_name + formatted_stage = stage_display + else: + # For non-table formats, use plain text + formatted_model = model_name + formatted_version = version_name + formatted_stage = stage_display + + return { + "model": formatted_model, + "version": formatted_version, + "stage": formatted_stage, + "tags": ", ".join(tag.name for tag in model_version.tags) + if model_version.tags else "", - "use_cases": model.use_cases, - "audience": model.audience, - "limitations": model.limitations, - "trade_offs": model.trade_offs, - "ethics": model.ethics, - "license": model.license, - "updated": model.updated.date(), + "updated": model_version.updated, } -def _model_version_to_print( - model_version: ModelVersionResponse, +def _generate_model_version_artifact_data( + model_version_artifact: ModelVersionArtifactResponse, ) -> Dict[str, Any]: + """Generate additional data for model version artifact display. + + Args: + model_version_artifact: The model version artifact response. + + Returns: + The additional data for the model version artifact. + """ + return { + "artifact_version": model_version_artifact.artifact_version.id, + } + + +def _generate_model_version_pipeline_run_data( + model_version_pipeline_run: ModelVersionPipelineRunResponse, +) -> Dict[str, Any]: + """Generate additional data for model version pipeline run display. + + Args: + model_version_pipeline_run: The model version pipeline run response. + + Returns: + The additional data for the model version pipeline run. + """ return { - "id": model_version.id, - "model": model_version.model.name, - "name": model_version.name, - "number": model_version.number, - "description": model_version.description, - "stage": model_version.stage, - "run_metadata": model_version.run_metadata, - "tags": [t.name for t in model_version.tags], - "updated": model_version.updated.date(), + "pipeline_run": model_version_pipeline_run.pipeline_run.id, } @@ -77,23 +152,39 @@ def model() -> None: """Interact with models and model versions in the Model Control Plane.""" -@cli_utils.list_options(ModelFilter) +@list_options( + ModelFilter, + default_columns=[ + "name", + "latest_version_name", + "latest_version_id", + "updated", + ], +) @model.command("list", help="List models with filter.") -def list_models(**kwargs: Any) -> None: +def list_models(output_format: str, columns: str, **kwargs: Any) -> None: """List models with filter in the Model Control Plane. Args: - **kwargs: Keyword arguments to filter models. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter models. """ - models = Client().list_models(**kwargs) - - if not models: - cli_utils.declare("No models found.") - return - to_print = [] - for model in models: - to_print.append(_model_to_print(model)) - cli_utils.print_table(to_print) + with console.status("Listing models..."): + models = Client().list_models(**kwargs) + + model_list = [] + for model in models.items: + model_data = cli_utils.prepare_response_data(model) + model_data.update(_generate_model_data(model)) + model_list.append(model_data) + + cli_utils.handle_output( + model_list, + pagination_info=models.pagination_info, + columns=columns, + output_format=output_format, + ) @model.command("register", help="Register a new model.") @@ -195,7 +286,7 @@ def register_model( registry. """ try: - model = Client().create_model( + Client().create_model( **remove_none_values( dict( name=name, @@ -214,8 +305,6 @@ def register_model( except (EntityExistsError, ValueError) as e: cli_utils.error(str(e)) - cli_utils.print_table([_model_to_print(model)]) - @model.command("update", help="Update an existing model.") @click.argument("model_name_or_id") @@ -344,9 +433,7 @@ def update_model( save_models_to_registry=save_models_to_registry, ) ) - model = Client().update_model(model_name_or_id=model_id, **update_dict) - - cli_utils.print_table([_model_to_print(model)]) + Client().update_model(model_name_or_id=model_id, **update_dict) @model.command("delete", help="Delete an existing model.") @@ -390,25 +477,38 @@ def version() -> None: """Interact with model versions in the Model Control Plane.""" -@cli_utils.list_options(ModelVersionFilter) +@list_options( + ModelVersionFilter, + default_columns=["model", "version", "stage", "tags", "updated"], +) @version.command("list", help="List model versions with filter.") -def list_model_versions(**kwargs: Any) -> None: +def list_model_versions( + output_format: str, columns: str, **kwargs: Any +) -> None: """List model versions with filter in the Model Control Plane. Args: - **kwargs: Keyword arguments to filter models. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter model versions. """ - model_versions = Client().list_model_versions(**kwargs) - - if not model_versions: - cli_utils.declare("No model versions found.") - return - - to_print = [] - for model_version in model_versions: - to_print.append(_model_version_to_print(model_version)) + with console.status("Listing model versions..."): + model_versions = Client().list_model_versions(**kwargs) + + model_version_list = [] + for model_version in model_versions.items: + model_version_data = cli_utils.prepare_response_data(model_version) + model_version_data.update( + _generate_model_version_data(model_version, output_format) + ) + model_version_list.append(model_version_data) - cli_utils.print_table(to_print) + cli_utils.handle_output( + model_version_list, + pagination_info=model_versions.pagination_info, + columns=columns, + output_format=output_format, + ) @version.command("update", help="Update an existing model version stage.") @@ -496,8 +596,6 @@ def update_model_version( ) except RuntimeError: if not force: - cli_utils.print_table([_model_version_to_print(model_version)]) - confirmation = cli_utils.confirmation( "Are you sure you want to change the status of model " f"version '{model_version_name_or_number_or_id}' to " @@ -517,7 +615,6 @@ def update_model_version( force=True, description=description, ) - cli_utils.print_table([_model_version_to_print(model_version)]) @version.command("delete", help="Delete an existing model version.") @@ -565,66 +662,17 @@ def delete_model_version( ) -def _print_artifacts_links_generic( - model_name_or_id: str, - model_version_name_or_number_or_id: Optional[str] = None, - only_data_artifacts: bool = False, - only_deployment_artifacts: bool = False, - only_model_artifacts: bool = False, - **kwargs: Any, -) -> None: - """Generic method to print artifacts links. - - Args: - model_name_or_id: The ID or name of the model containing version. - model_version_name_or_number_or_id: The name, number or ID of the model version. - only_data_artifacts: If set, only print data artifacts. - only_deployment_artifacts: If set, only print deployment artifacts. - only_model_artifacts: If set, only print model artifacts. - **kwargs: Keyword arguments to filter models. - """ - model_version = Client().get_model_version( - model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id, - ) - type_ = ( - "data artifacts" - if only_data_artifacts - else "deployment artifacts" - if only_deployment_artifacts - else "model artifacts" - ) - - links = Client().list_model_version_artifact_links( - model_version_id=model_version.id, - only_data_artifacts=only_data_artifacts, - only_deployment_artifacts=only_deployment_artifacts, - only_model_artifacts=only_model_artifacts, - **kwargs, - ) - - if not links: - cli_utils.declare(f"No {type_} linked to the model version found.") - return - - cli_utils.title( - f"{type_} linked to the model version `{model_version.name}[{model_version.number}]`:" - ) - cli_utils.print_pydantic_models( - links, - columns=["artifact_version", "created"], - ) - - @model.command( "data_artifacts", help="List data artifacts linked to a model version.", ) @click.argument("model_name") @click.option("--model_version", "-v", default=None) -@cli_utils.list_options(ModelVersionArtifactFilter) +@list_options(ModelVersionArtifactFilter) def list_model_version_data_artifacts( model_name: str, + output_format: str, + columns: str, model_version: Optional[str] = None, **kwargs: Any, ) -> None: @@ -632,15 +680,35 @@ def list_model_version_data_artifacts( Args: model_name: The ID or name of the model containing version. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. model_version: The name, number or ID of the model version. If not provided, the latest version is used. - **kwargs: Keyword arguments to filter models. + kwargs: Keyword arguments to filter models. """ - _print_artifacts_links_generic( + model_version_obj = Client().get_model_version( model_name_or_id=model_name, model_version_name_or_number_or_id=model_version, - only_data_artifacts=True, - **kwargs, + ) + + with console.status("Listing data artifacts..."): + links = Client().list_model_version_artifact_links( + model_version_id=model_version_obj.id, + only_data_artifacts=True, + **kwargs, + ) + + artifact_list = [] + for link in links.items: + artifact_data = cli_utils.prepare_response_data(link) + artifact_data.update(_generate_model_version_artifact_data(link)) + artifact_list.append(artifact_data) + + cli_utils.handle_output( + artifact_list, + pagination_info=links.pagination_info, + columns=columns, + output_format=output_format, ) @@ -650,9 +718,11 @@ def list_model_version_data_artifacts( ) @click.argument("model_name") @click.option("--model_version", "-v", default=None) -@cli_utils.list_options(ModelVersionArtifactFilter) +@list_options(ModelVersionArtifactFilter) def list_model_version_model_artifacts( model_name: str, + output_format: str, + columns: str, model_version: Optional[str] = None, **kwargs: Any, ) -> None: @@ -660,15 +730,35 @@ def list_model_version_model_artifacts( Args: model_name: The ID or name of the model containing version. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. model_version: The name, number or ID of the model version. If not provided, the latest version is used. - **kwargs: Keyword arguments to filter models. + kwargs: Keyword arguments to filter models. """ - _print_artifacts_links_generic( + model_version_obj = Client().get_model_version( model_name_or_id=model_name, model_version_name_or_number_or_id=model_version, - only_model_artifacts=True, - **kwargs, + ) + + with console.status("Listing model artifacts..."): + links = Client().list_model_version_artifact_links( + model_version_id=model_version_obj.id, + only_model_artifacts=True, + **kwargs, + ) + + artifact_list = [] + for link in links.items: + artifact_data = cli_utils.prepare_response_data(link) + artifact_data.update(_generate_model_version_artifact_data(link)) + artifact_list.append(artifact_data) + + cli_utils.handle_output( + artifact_list, + pagination_info=links.pagination_info, + columns=columns, + output_format=output_format, ) @@ -678,9 +768,11 @@ def list_model_version_model_artifacts( ) @click.argument("model_name") @click.option("--model_version", "-v", default=None) -@cli_utils.list_options(ModelVersionArtifactFilter) +@list_options(ModelVersionArtifactFilter) def list_model_version_deployment_artifacts( model_name: str, + output_format: str, + columns: str, model_version: Optional[str] = None, **kwargs: Any, ) -> None: @@ -688,15 +780,35 @@ def list_model_version_deployment_artifacts( Args: model_name: The ID or name of the model containing version. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. model_version: The name, number or ID of the model version. If not provided, the latest version is used. - **kwargs: Keyword arguments to filter models. + kwargs: Keyword arguments to filter models. """ - _print_artifacts_links_generic( + model_version_obj = Client().get_model_version( model_name_or_id=model_name, model_version_name_or_number_or_id=model_version, - only_deployment_artifacts=True, - **kwargs, + ) + + with console.status("Listing deployment artifacts..."): + links = Client().list_model_version_artifact_links( + model_version_id=model_version_obj.id, + only_deployment_artifacts=True, + **kwargs, + ) + + artifact_list = [] + for link in links.items: + artifact_data = cli_utils.prepare_response_data(link) + artifact_data.update(_generate_model_version_artifact_data(link)) + artifact_list.append(artifact_data) + + cli_utils.handle_output( + artifact_list, + pagination_info=links.pagination_info, + columns=columns, + output_format=output_format, ) @@ -706,9 +818,11 @@ def list_model_version_deployment_artifacts( ) @click.argument("model_name") @click.option("--model_version", "-v", default=None) -@cli_utils.list_options(ModelVersionPipelineRunFilter) +@list_options(ModelVersionPipelineRunFilter) def list_model_version_pipeline_runs( model_name: str, + output_format: str, + columns: str, model_version: Optional[str] = None, **kwargs: Any, ) -> None: @@ -716,25 +830,32 @@ def list_model_version_pipeline_runs( Args: model_name: The ID or name of the model containing version. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. model_version: The name, number or ID of the model version. If not provided, the latest version is used. - **kwargs: Keyword arguments to filter runs. + kwargs: Keyword arguments to filter runs. """ model_version_response_model = Client().get_model_version( model_name_or_id=model_name, model_version_name_or_number_or_id=model_version, ) - runs = Client().list_model_version_pipeline_run_links( - model_version_id=model_version_response_model.id, - **kwargs, - ) - - if not runs: - cli_utils.declare("No pipeline runs attached to model version found.") - return + with console.status("Listing pipeline runs..."): + runs = Client().list_model_version_pipeline_run_links( + model_version_id=model_version_response_model.id, + **kwargs, + ) - cli_utils.title( - f"Pipeline runs linked to the model version `{model_version_response_model.name}[{model_version_response_model.number}]`:" + run_list = [] + for run in runs.items: + run_data = cli_utils.prepare_response_data(run) + run_data.update(_generate_model_version_pipeline_run_data(run)) + run_list.append(run_data) + + cli_utils.handle_output( + run_list, + pagination_info=runs.pagination_info, + columns=columns, + output_format=output_format, ) - cli_utils.print_pydantic_models(runs) diff --git a/src/zenml/cli/pipeline.py b/src/zenml/cli/pipeline.py index cd1c01af985..dcfccb74414 100644 --- a/src/zenml/cli/pipeline.py +++ b/src/zenml/cli/pipeline.py @@ -30,10 +30,14 @@ from zenml.models import ( PipelineBuildBase, PipelineBuildFilter, + PipelineBuildResponse, PipelineFilter, + PipelineResponse, PipelineRunFilter, + PipelineRunResponse, PipelineSnapshotFilter, ScheduleFilter, + ScheduleResponse, ) from zenml.pipelines.pipeline_definition import Pipeline from zenml.utils import run_utils, source_utils, uuid_utils @@ -42,6 +46,81 @@ logger = get_logger(__name__) +def _generate_pipeline_data(pipeline: PipelineResponse) -> Dict[str, Any]: + """Generate additional data for pipeline display. + + Args: + pipeline: The pipeline response. + + Returns: + The additional data for the pipeline. + """ + return { + "latest_run_status": pipeline.latest_run_status or "", + "latest_run_id": pipeline.latest_run_id or "", + "tags": ", ".join(tag.name for tag in pipeline.tags) + if pipeline.tags + else "", + "created": pipeline.created, + } + + +def _generate_schedule_data(schedule: ScheduleResponse) -> Dict[str, Any]: + """Generate additional data for schedule display. + + Args: + schedule: The schedule response. + + Returns: + The additional data for the schedule. + """ + return { + "active": schedule.active, + "cron_expression": schedule.cron_expression, + } + + +def _generate_pipeline_run_data( + pipeline_run: PipelineRunResponse, +) -> Dict[str, Any]: + """Generate additional data for pipeline run display. + + Args: + pipeline_run: The pipeline run response. + + Returns: + The additional data for the pipeline run. + """ + return { + "pipeline": pipeline_run.pipeline.name + if pipeline_run.pipeline + else "", + "stack": pipeline_run.stack.name if pipeline_run.stack else "", + } + + +def _generate_pipeline_build_data( + pipeline_build: PipelineBuildResponse, +) -> Dict[str, Any]: + """Generate additional data for pipeline build display. + + Args: + pipeline_build: The pipeline build response. + + Returns: + The additional data for the pipeline build. + """ + return { + "pipeline_name": pipeline_build.pipeline.name + if pipeline_build.pipeline + else "", + "zenml_version": pipeline_build.zenml_version, + "stack_name": pipeline_build.stack.name + if pipeline_build.stack + else "", + } + + def _import_pipeline(source: str) -> Pipeline: """Import a pipeline. @@ -371,26 +450,41 @@ def create_run_template( cli_utils.declare(f"Created run template `{template.id}`.") +@list_options( + PipelineFilter, + default_columns=[ + "id", + "name", + "latest_run_status", + "latest_run_id", + "tags", + "created", + ], +) @pipeline.command("list", help="List all registered pipelines.") -@list_options(PipelineFilter) -def list_pipelines(**kwargs: Any) -> None: +def list_pipelines(output_format: str, columns: str, **kwargs: Any) -> None: """List all registered pipelines. Args: - **kwargs: Keyword arguments to filter pipelines. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter pipelines. """ - client = Client() - with console.status("Listing pipelines...\n"): - pipelines = client.list_pipelines(**kwargs) - - if not pipelines.items: - cli_utils.declare("No pipelines found for this filter.") - return - - cli_utils.print_pydantic_models( - pipelines, - exclude_columns=["id", "created", "updated", "user", "project"], - ) + with console.status("Listing pipelines..."): + pipelines = Client().list_pipelines(**kwargs) + + pipeline_list = [] + for pipeline in pipelines.items: + pipeline_data = cli_utils.prepare_response_data(pipeline) + pipeline_data.update(_generate_pipeline_data(pipeline)) + pipeline_list.append(pipeline_data) + + cli_utils.handle_output( + pipeline_list, + pagination_info=pipelines.pagination_info, + columns=columns, + output_format=output_format, + ) @pipeline.command("delete") @@ -436,25 +530,33 @@ def schedule() -> None: """Commands for pipeline run schedules.""" +@list_options( + ScheduleFilter, + default_columns=["name", "active", "cron_expression", "user", "created"], +) @schedule.command("list", help="List all pipeline schedules.") -@list_options(ScheduleFilter) -def list_schedules(**kwargs: Any) -> None: +def list_schedules(output_format: str, columns: str, **kwargs: Any) -> None: """List all pipeline schedules. Args: - **kwargs: Keyword arguments to filter schedules. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter schedules. """ - client = Client() - - schedules = client.list_schedules(**kwargs) - - if not schedules: - cli_utils.declare("No schedules found for this filter.") - return - - cli_utils.print_pydantic_models( - schedules, - exclude_columns=["id", "created", "updated", "user", "project"], + with console.status("Listing schedules..."): + schedules = Client().list_schedules(**kwargs) + + schedule_list = [] + for schedule in schedules.items: + schedule_data = cli_utils.prepare_response_data(schedule) + schedule_data.update(_generate_schedule_data(schedule)) + schedule_list.append(schedule_data) + + cli_utils.handle_output( + schedule_list, + pagination_info=schedules.pagination_info, + columns=columns, + output_format=output_format, ) @@ -528,27 +630,44 @@ def runs() -> None: """Commands for pipeline runs.""" +@list_options( + PipelineRunFilter, + default_columns=[ + "id", + "name", + "status", + "pipeline", + "user", + "stack", + "created", + ], +) @runs.command("list", help="List all registered pipeline runs.") -@list_options(PipelineRunFilter) -def list_pipeline_runs(**kwargs: Any) -> None: +def list_pipeline_runs( + output_format: str, columns: str, **kwargs: Any +) -> None: """List all registered pipeline runs for the filter. Args: - **kwargs: Keyword arguments to filter pipeline runs. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter pipeline runs. """ - client = Client() - try: - with console.status("Listing pipeline runs...\n"): - pipeline_runs = client.list_pipeline_runs(**kwargs) - except KeyError as err: - cli_utils.error(str(err)) - else: - if not pipeline_runs.items: - cli_utils.declare("No pipeline runs found for this filter.") - return - - cli_utils.print_pipeline_runs_table(pipeline_runs=pipeline_runs.items) - cli_utils.print_page_info(pipeline_runs) + with console.status("Listing pipeline runs..."): + pipeline_runs = Client().list_pipeline_runs(**kwargs) + + pipeline_run_list = [] + for pipeline_run in pipeline_runs.items: + pipeline_run_data = cli_utils.prepare_response_data(pipeline_run) + pipeline_run_data.update(_generate_pipeline_run_data(pipeline_run)) + pipeline_run_list.append(pipeline_run_data) + + cli_utils.handle_output( + pipeline_run_list, + pagination_info=pipeline_runs.pagination_info, + columns=columns, + output_format=output_format, + ) @runs.command("stop") @@ -679,36 +798,46 @@ def builds() -> None: """Commands for pipeline builds.""" +@list_options( + PipelineBuildFilter, + default_columns=[ + "id", + "pipeline_name", + "zenml_version", + "stack_name", + "created", + ], +) @builds.command("list", help="List all pipeline builds.") -@list_options(PipelineBuildFilter) -def list_pipeline_builds(**kwargs: Any) -> None: +def list_pipeline_builds( + output_format: str, columns: str, **kwargs: Any +) -> None: """List all pipeline builds for the filter. Args: - **kwargs: Keyword arguments to filter pipeline builds. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter pipeline builds. """ - client = Client() - try: - with console.status("Listing pipeline builds...\n"): - pipeline_builds = client.list_builds(hydrate=True, **kwargs) - except KeyError as err: - cli_utils.error(str(err)) - else: - if not pipeline_builds.items: - cli_utils.declare("No pipeline builds found for this filter.") - return + with console.status("Listing pipeline builds..."): + pipeline_builds = Client().list_builds(hydrate=True, **kwargs) - cli_utils.print_pydantic_models( - pipeline_builds, - exclude_columns=[ - "created", - "updated", - "user", - "project", - "images", - "stack_checksum", - ], - ) + pipeline_build_list = [] + for pipeline_build in pipeline_builds.items: + pipeline_build_data = cli_utils.prepare_response_data( + pipeline_build + ) + pipeline_build_data.update( + _generate_pipeline_build_data(pipeline_build) + ) + pipeline_build_list.append(pipeline_build_data) + + cli_utils.handle_output( + pipeline_build_list, + pagination_info=pipeline_builds.pagination_info, + columns=columns, + output_format=output_format, + ) @builds.command("delete") diff --git a/src/zenml/cli/project.py b/src/zenml/cli/project.py index b4807acbde7..f0f78cd3ac2 100644 --- a/src/zenml/cli/project.py +++ b/src/zenml/cli/project.py @@ -21,7 +21,6 @@ from zenml.cli.cli import TagGroup, cli from zenml.cli.utils import ( check_zenml_pro_project_availability, - is_sorted_or_filtered, list_options, ) from zenml.client import Client @@ -35,33 +34,44 @@ def project() -> None: """Commands for project management.""" +@list_options( + ProjectFilter, default_columns=["name", "description", "created"] +) @project.command("list") -@list_options(ProjectFilter) @click.pass_context -def list_projects(ctx: click.Context, /, **kwargs: Any) -> None: +def list_projects( + ctx: click.Context, output_format: str, columns: str, /, **kwargs: Any +) -> None: """List all projects. Args: ctx: The click context object - **kwargs: Keyword arguments to filter the list of projects. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter the list of projects. """ check_zenml_pro_project_availability() + client = Client() - with console.status("Listing projects...\n"): + with console.status("Listing projects..."): projects = client.list_projects(**kwargs) - if projects: - try: - active_project = [client.active_project] - except Exception: - active_project = [] - cli_utils.print_pydantic_models( - projects, - exclude_columns=["id", "created", "updated"], - active_models=active_project, - show_active=not is_sorted_or_filtered(ctx), - ) - else: - cli_utils.declare("No projects found for the given filter.") + + project_list = [] + for project in projects.items: + project_data = cli_utils.prepare_response_data(project) + project_data.update( + { + "description": project.description, + } + ) + project_list.append(project_data) + + cli_utils.handle_output( + project_list, + pagination_info=projects.pagination_info, + columns=columns, + output_format=output_format, + ) @project.command("register") diff --git a/src/zenml/cli/secret.py b/src/zenml/cli/secret.py index 5c5a38b74e0..f83c6585516 100644 --- a/src/zenml/cli/secret.py +++ b/src/zenml/cli/secret.py @@ -14,10 +14,11 @@ """Functionality to generate stack component CLI commands.""" import getpass -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import click +from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli from zenml.cli.utils import ( confirmation, @@ -28,8 +29,6 @@ list_options, parse_name_and_extra_arguments, pretty_print_secret, - print_page_info, - print_table, validate_keys, warning, ) @@ -46,6 +45,7 @@ logger = get_logger(__name__) + @cli.group(cls=TagGroup, tag=CliCategories.IDENTITY_AND_SECURITY) def secret() -> None: """Create, list, update, or delete secrets.""" @@ -162,14 +162,18 @@ def create_secret( error(f"Centralized secrets management is disabled: {str(e)}") +@list_options( + SecretFilter, default_columns=["name", "scope", "user", "created"] +) @secret.command( "list", help="List all registered secrets that match the filter criteria." ) -@list_options(SecretFilter) -def list_secrets(**kwargs: Any) -> None: +def list_secrets(output_format: str, columns: str, **kwargs: Any) -> None: """List all secrets that fulfill the filter criteria. Args: + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. kwargs: Keyword arguments to filter the secrets. """ client = Client() @@ -178,20 +182,21 @@ def list_secrets(**kwargs: Any) -> None: secrets = client.list_secrets(**kwargs) except NotImplementedError as e: error(f"Centralized secrets management is disabled: {str(e)}") - if not secrets.items: - warning("No secrets found for the given filters.") - return - secret_rows = [ - dict( - name=secret.name, - id=str(secret.id), - private=secret.private, - ) - for secret in secrets.items - ] - print_table(secret_rows) - print_page_info(secrets) + secret_list = [] + for secret in secrets.items: + secret_data = cli_utils.prepare_response_data(secret) + secret_data.update({ + "scope": "private" if secret.private else "public", + }) + secret_list.append(secret_data) + + cli_utils.handle_output( + secret_list, + pagination_info=secrets.pagination_info, + columns=columns, + output_format=output_format, + ) @secret.command("get", help="Get a secret with a given name, prefix or id.") diff --git a/src/zenml/cli/service_accounts.py b/src/zenml/cli/service_accounts.py index d5a4ef13b09..64e493b3522 100644 --- a/src/zenml/cli/service_accounts.py +++ b/src/zenml/cli/service_accounts.py @@ -13,23 +13,26 @@ # permissions and limitations under the License. """CLI functionality to interact with API keys.""" -from typing import Any, Optional +from typing import Any, Dict, Optional import click from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import list_options +from zenml.cli.utils import ( + list_options, +) from zenml.client import Client from zenml.console import console from zenml.enums import CliCategories, StoreType from zenml.exceptions import EntityExistsError, IllegalOperationError from zenml.logger import get_logger -from zenml.models import APIKeyFilter, ServiceAccountFilter +from zenml.models import APIKeyFilter, APIKeyResponse, ServiceAccountFilter logger = get_logger(__name__) + def _create_api_key( service_account_name_or_id: str, name: str, @@ -185,31 +188,37 @@ def describe_service_account(service_account_name_or_id: str) -> None: @service_account.command("list") -@list_options(ServiceAccountFilter) +@list_options( + ServiceAccountFilter, + default_columns=["id", "name", "description", "active", "created"], +) @click.pass_context -def list_service_accounts(ctx: click.Context, /, **kwargs: Any) -> None: - """List all users. +def list_service_accounts( + ctx: click.Context, output_format: str, columns: str, /, **kwargs: Any +) -> None: + """List all service accounts. Args: ctx: The click context object - kwargs: Keyword arguments to filter the list of users. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter the list of service accounts. """ client = Client() - with console.status("Listing service accounts...\n"): + with console.status("Listing service accounts..."): service_accounts = client.list_service_accounts(**kwargs) - if not service_accounts: - cli_utils.declare( - "No service accounts found for the given filters." - ) - return - cli_utils.print_pydantic_models( - service_accounts, - exclude_columns=[ - "created", - "updated", - ], - ) + service_account_list = [] + for service_account in service_accounts.items: + service_account_data = cli_utils.prepare_response_data(service_account) + service_account_list.append(service_account_data) + + cli_utils.handle_output( + service_account_list, + pagination_info=service_accounts.pagination_info, + columns=columns, + output_format=output_format, + ) @service_account.command( @@ -381,40 +390,49 @@ def describe_api_key(service_account_name_or_id: str, name_or_id: str) -> None: ) +@list_options( + APIKeyFilter, + default_columns=["id", "name", "description", "active", "created"], +) @api_key.command("list", help="List all API keys.") -@list_options(APIKeyFilter) @click.pass_obj -def list_api_keys(service_account_name_or_id: str, /, **kwargs: Any) -> None: +def list_api_keys( + service_account_name_or_id: str, + output_format: str, + columns: str, + /, + **kwargs: Any, +) -> None: """List all API keys. Args: service_account_name_or_id: The name or ID of the service account for which to list the API keys. - **kwargs: Keyword arguments to filter API keys. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter API keys. """ - with console.status("Listing API keys...\n"): - try: - api_keys = Client().list_api_keys( - service_account_name_id_or_prefix=service_account_name_or_id, - **kwargs, - ) - except KeyError as e: - cli_utils.error(str(e)) - - if not api_keys.items: - cli_utils.declare("No API keys found for this filter.") - return - - cli_utils.print_pydantic_models( - api_keys, - exclude_columns=[ - "created", - "updated", - "key", - "retain_period_minutes", - ], + with console.status("Listing API keys..."): + api_keys = Client().list_api_keys( + service_account_name_id_or_prefix=service_account_name_or_id, + **kwargs, ) + api_key_list = [] + for api_key in api_keys.items: + api_key_data = cli_utils.prepare_response_data(api_key) + api_key_data.update({ + "active": api_key.active, + }) + api_key_list.append(api_key_data) + + cli_utils.handle_output( + api_key_list, + pagination_info=api_keys.pagination_info, + columns=columns, + output_format=output_format, + ) + @api_key.command("update", help="Update an API key.") @click.argument("name_or_id", type=str, required=True) diff --git a/src/zenml/cli/service_connectors.py b/src/zenml/cli/service_connectors.py index f0f64bcea81..2a6dc1e3029 100644 --- a/src/zenml/cli/service_connectors.py +++ b/src/zenml/cli/service_connectors.py @@ -22,9 +22,7 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli from zenml.cli.utils import ( - is_sorted_or_filtered, list_options, - print_page_info, ) from zenml.client import Client from zenml.console import console @@ -39,6 +37,152 @@ from zenml.utils.time_utils import seconds_to_human_readable, utc_now +def _get_connector_type_emoji_and_short_name( + connector_type_id: str, +) -> tuple[str, str]: + """Get emoji and short name for connector type. + + Args: + connector_type_id: The connector type identifier (e.g., 'aws', 'gcp') + + Returns: + Tuple of (emoji, short_name) + """ + connector_mappings = { + "aws": ("🔶", "aws"), # Orange diamond/rhomboid + "aws-generic": ( + "🔶", + "aws", + ), # Orange diamond/rhomboid - same as main AWS + "gcp": ("🔵", "gcp"), # Blue dot + "gcp-generic": ("🔵", "gcp"), # Blue dot - same as main GCP + "google-cloud": ("🔵", "gcp"), # Blue dot + "azure": ("🔷", "azure"), + "azure-generic": ("🔷", "azure"), # Blue diamond - same as main Azure + "kubernetes": ("🌀", "k8s"), # Blue spiral + "kubernetes-generic": ( + "🌀", + "k8s", + ), # Blue spiral - same as main Kubernetes + "docker": ("🐳", "docker"), + "docker-generic": ("🐳", "docker"), # Whale - same as main Docker + "github": ("🐙", "github"), + "gitlab": ("🦊", "gitlab"), + "hyperai": ("🚀", "hyperai"), + "slack": ("💬", "slack"), + "discord": ("🎮", "discord"), + "teams": ("👥", "teams"), + } + + return connector_mappings.get( + connector_type_id.lower(), ("🔗", connector_type_id) + ) + + +def _format_resource_types_with_emojis( + resource_types: List[str], connector_type: str = "" +) -> str: + """Format resource types with emojis, each on a separate line. + + Args: + resource_types: List of resource type identifiers + connector_type: The connector type to determine generic resource emoji + + Returns: + Formatted string with emojis and resource types, one per line + """ + resource_emojis = { + "s3-bucket": "📦", # Box + "gcs-bucket": "📦", # Box + "kubernetes-cluster": "🌀", # Blue spiral + "docker-registry": "🐳", + "azure-blob": "💾", + "ecr-registry": "📦", + "gcr-registry": "📦", + "acr-registry": "📦", + } + + # Service-specific emojis for generic resources + generic_service_emojis = { + "aws": "🔶", + "gcp": "🔵", + "azure": "🔷", + "kubernetes": "🌀", + "docker": "🐳", + "github": "🐙", + "gitlab": "🦊", + "hyperai": "🚀", + "slack": "💬", + "discord": "🎮", + } + + if not resource_types: + return "" + + # Determine default emoji for generic resources based on connector type + default_emoji = generic_service_emojis.get(connector_type.lower(), "📋") + + # Limit to first 5 resource types to avoid overly tall cells + display_types = resource_types[:5] + formatted = [] + + for rt in display_types: + emoji = resource_emojis.get(rt, default_emoji) + formatted.append(f"{emoji} {rt}") + + # Join with newlines for multi-line display + result = "\n".join(formatted) + if len(resource_types) > 5: + result += f"\n+{len(resource_types) - 5} more" + + return result + + +def _generate_service_connector_data( + service_connector: ServiceConnectorResponse, + output_format: str, +) -> Dict[str, Any]: + """Generate additional data for service connector display. + + Args: + service_connector: The service connector response. + output_format: The output format. + + Returns: + The additional data for the service connector. + """ + # Get active connector IDs from current stack for enrichment + active_stack = Client().active_stack_model + + active_connector_ids = [] + for components in active_stack.components.values(): + active_connector_ids.extend( + [ + component.connector.id + for component in components + if component.connector + ] + ) + is_active = service_connector.id in active_connector_ids + + result = {"is_active": is_active} + + # Add formatted resource types with emojis + if output_format == "table" and hasattr(service_connector, "resource_types"): + result["resource_types"] = _format_resource_types_with_emojis( + service_connector.resource_types, service_connector.type + ) + + # Add type with emoji and short name + if output_format == "table" and hasattr(service_connector, "type"): + emoji, short_name = _get_connector_type_emoji_and_short_name( + service_connector.type + ) + result["type_display"] = f"{emoji} {short_name}" + + return result + + # Service connectors @cli.group( cls=TagGroup, @@ -959,12 +1103,14 @@ def register_service_connector( ) +@list_options( + ServiceConnectorFilter, + default_columns=["id", "name", "type_display", "resource_types", "user"], +) @service_connector.command( "list", - help="""List available service connectors. -""", + help="""List available service connectors.""", ) -@list_options(ServiceConnectorFilter) @click.option( "--label", "-l", @@ -973,14 +1119,18 @@ def register_service_connector( "can be used multiple times.", multiple=True, ) -@click.pass_context def list_service_connectors( - ctx: click.Context, /, labels: Optional[List[str]] = None, **kwargs: Any + output_format: str, + columns: str, + labels: Optional[List[str]] = None, + **kwargs: Any, ) -> None: """List all service connectors. Args: ctx: The click context object + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. labels: Labels to filter by. kwargs: Keyword arguments to filter the components. """ @@ -991,17 +1141,25 @@ def list_service_connectors( labels, allow_label_only=True ) - connectors = client.list_service_connectors(**kwargs) - if not connectors: - cli_utils.declare("No service connectors found for the given filters.") - return + with console.status("Listing service connectors..."): + connectors = client.list_service_connectors(**kwargs) + + connector_list = [] + for connector in connectors.items: + connector_data = cli_utils.prepare_response_data(connector) + connector_data.update( + _generate_service_connector_data( + connector, output_format + ) + ) + connector_list.append(connector_data) - cli_utils.print_service_connectors_table( - client=client, - connectors=connectors.items, - show_active=not is_sorted_or_filtered(ctx), + cli_utils.handle_output( + connector_list, + pagination_info=connectors.pagination_info, + columns=columns, + output_format=output_format, ) - print_page_info(connectors) @service_connector.command( diff --git a/src/zenml/cli/stack.py b/src/zenml/cli/stack.py index 12d29b95fdf..fc6bf620ed9 100644 --- a/src/zenml/cli/stack.py +++ b/src/zenml/cli/stack.py @@ -43,11 +43,9 @@ from zenml.cli.text_utils import OldSchoolMarkdownHeading from zenml.cli.utils import ( _component_display_name, - is_sorted_or_filtered, list_options, + prepare_response_data, print_model_url, - print_page_info, - print_stacks_table, ) from zenml.client import Client from zenml.console import console @@ -66,6 +64,7 @@ ServiceConnectorResourcesInfo, StackFilter, StackRequest, + StackResponse, ) from zenml.models.v2.core.service_connector import ( ServiceConnectorRequest, @@ -569,7 +568,7 @@ def register_stack( connectors.add(conn_.name) for connector in connectors: delete_commands.append( - "zenml service-connector delete " + connector + f"zenml service-connector delete {connector}" ) for each in created_objects: if comps_ := created_stack.components[StackComponentType(each)]: @@ -1006,28 +1005,92 @@ def rename_stack( print_model_url(get_stack_url(stack_)) +def _generate_stack_data( + stack: StackResponse, output_format: str +) -> List[Dict[str, Any]]: + """Generate additional data for the stack to display in the output. + + Args: + stack: The stack response. + table_args: The table arguments. + + Returns: + The additional data for the stack. + """ + from zenml.enums import StackComponentType + + client = Client() + + active_stack_id = client.active_stack_model.id + is_active = stack.id == active_stack_id + + result = { + "orchestrator": "-", + "artifact_store": "-", + "is_active": is_active, + "components": len(stack.components) + if hasattr(stack, "components") + else 0, + } + + if hasattr(stack, "components") and stack.components: + if StackComponentType.ORCHESTRATOR in stack.components: + result["orchestrator"] = stack.components[ + StackComponentType.ORCHESTRATOR + ][0].name + if StackComponentType.ARTIFACT_STORE in stack.components: + result["artifact_store"] = stack.components[ + StackComponentType.ARTIFACT_STORE + ][0].name + + result["name"] = stack.name + + if is_active and output_format == "table": + result["name"] = ( + f"[green]●[/green] [bold green]{stack.name}[/bold green] (active)" + ) + return result + + @stack.command("list") -@list_options(StackFilter) -@click.pass_context -def list_stacks(ctx: click.Context, /, **kwargs: Any) -> None: +@list_options( + StackFilter, + default_columns=[ + "id", + "name", + "owner", + "components", + "orchestrator", + "artifact_store", + ], +) +def list_stacks(output_format: str, columns: str, **kwargs: Any) -> None: """List all stacks that fulfill the filter requirements. Args: - ctx: the Click context + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. kwargs: Keyword arguments to filter the stacks. """ client = Client() - with console.status("Listing stacks...\n"): + + with console.status("Listing stacks..."): stacks = client.list_stacks(**kwargs) - if not stacks: - cli_utils.declare("No stacks found for the given filters.") - return - print_stacks_table( - client=client, - stacks=stacks.items, - show_active=not is_sorted_or_filtered(ctx), - ) - print_page_info(stacks) + + stack_list = [] + for stack in stacks.items: + stack_data = prepare_response_data(stack) + stack_data.update( + _generate_stack_data(stack, output_format=output_format) + ) + stack_list.append(stack_data) + + cli_utils.handle_output( + stack_list, + pagination_info=stacks.pagination_info, + columns=columns, + output_format=output_format, + ) @stack.command( diff --git a/src/zenml/cli/stack_components.py b/src/zenml/cli/stack_components.py index 5037e4f9240..737c02582cf 100644 --- a/src/zenml/cli/stack_components.py +++ b/src/zenml/cli/stack_components.py @@ -15,7 +15,7 @@ import time from importlib import import_module -from typing import Any, Callable, List, Optional, Tuple, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, cast from uuid import UUID import click @@ -29,10 +29,8 @@ from zenml.cli.served_model import register_model_deployer_subcommands from zenml.cli.utils import ( _component_display_name, - is_sorted_or_filtered, list_options, print_model_url, - print_page_info, ) from zenml.client import Client from zenml.console import console @@ -41,12 +39,43 @@ from zenml.io import fileio from zenml.models import ( ComponentFilter, + ComponentResponse, + FlavorFilter, + FlavorResponse, ServiceConnectorResourcesModel, ) from zenml.utils import source_utils from zenml.utils.dashboard_utils import get_component_url +def _generate_component_data(component: ComponentResponse) -> Dict[str, Any]: + """Generate additional data for component display. + + Args: + component: The component response. + + Returns: + The additional data for the component. + """ + return { + "flavor": component.flavor_name if component.flavor_name else "", + } + + +def _generate_flavor_data(flavor: FlavorResponse) -> Dict[str, Any]: + """Generate additional data for flavor display. + + Args: + flavor: The flavor response. + + Returns: + The additional data for the flavor. + """ + return { + "integration": flavor.integration if flavor.integration else "", + } + + def generate_stack_component_get_command( component_type: StackComponentType, ) -> Callable[[], None]: @@ -152,32 +181,39 @@ def generate_stack_component_list_command( A function that can be used as a `click` command. """ - @list_options(ComponentFilter) + @list_options( + ComponentFilter, + default_columns=["id", "name", "flavor", "user", "created"], + ) @click.pass_context def list_stack_components_command( - ctx: click.Context, /, **kwargs: Any + ctx: click.Context, output_format: str, columns: str, /, **kwargs: Any ) -> None: """Prints a table of stack components. Args: ctx: The click context object + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. kwargs: Keyword arguments to filter the components. """ client = Client() with console.status(f"Listing {component_type.plural}..."): kwargs["type"] = component_type components = client.list_stack_components(**kwargs) - if not components: - cli_utils.declare("No components found for the given filters.") - return - cli_utils.print_components_table( - client=client, - component_type=component_type, - components=components.items, - show_active=not is_sorted_or_filtered(ctx), - ) - print_page_info(components) + component_list = [] + for component in components.items: + component_data = cli_utils.prepare_response_data(component) + component_data.update(_generate_component_data(component)) + component_list.append(component_data) + + cli_utils.handle_output( + component_list, + pagination_info=components.pagination_info, + columns=columns, + output_format=output_format, + ) return list_stack_components_command @@ -782,15 +818,38 @@ def generate_stack_component_flavor_list_command( """ display_name = _component_display_name(component_type) - def list_stack_component_flavor_command() -> None: - """Lists the flavors for a single type of stack component.""" - client = Client() + @list_options( + FlavorFilter, + default_columns=["id", "name", "integration", "connector_type"], + ) + @click.pass_context + def list_stack_component_flavor_command( + ctx: click.Context, output_format: str, columns: str, /, **kwargs: Any + ) -> None: + """Lists the flavors for a single type of stack component. - with console.status(f"Listing {display_name} flavors`...\n"): + Args: + ctx: The click context. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: The keyword arguments. + """ + client = Client() + with console.status(f"Listing {display_name} flavors..."): flavors = client.get_flavors_by_type(component_type=component_type) - cli_utils.print_flavor_list(flavors=flavors) - cli_utils.print_page_info(flavors) + flavor_list = [] + for flavor in flavors.items: + flavor_data = cli_utils.prepare_response_data(flavor) + flavor_data.update(_generate_flavor_data(flavor)) + flavor_list.append(flavor_data) + + cli_utils.handle_output( + flavor_list, + pagination_info=flavors.pagination_info, + columns=columns, + output_format=output_format, + ) return list_stack_component_flavor_command diff --git a/src/zenml/cli/tag.py b/src/zenml/cli/tag.py index 5d81f8e393f..d1c454da071 100644 --- a/src/zenml/cli/tag.py +++ b/src/zenml/cli/tag.py @@ -20,7 +20,9 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli +from zenml.cli.utils import list_options from zenml.client import Client +from zenml.console import console from zenml.enums import CliCategories, ColorVariants from zenml.exceptions import EntityExistsError from zenml.logger import get_logger @@ -37,23 +39,29 @@ def tag() -> None: """Interact with tags.""" -@cli_utils.list_options(TagFilter) +@list_options(TagFilter, default_columns=["id", "name", "color", "created"]) @tag.command("list", help="List tags with filter.") -def list_tags(**kwargs: Any) -> None: +def list_tags(output_format: str, columns: str, **kwargs: Any) -> None: """List tags with filter. Args: - **kwargs: Keyword arguments to filter models. + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. + kwargs: Keyword arguments to filter tags. """ - tags = Client().list_tags(**kwargs) - - if not tags: - cli_utils.declare("No tags found.") - return - - cli_utils.print_pydantic_models( - tags, - exclude_columns=["created"], + with console.status("Listing tags..."): + tags = Client().list_tags(**kwargs) + + tag_list = [] + for tag in tags.items: + tag_data = cli_utils.prepare_response_data(tag) + tag_list.append(tag_data) + + cli_utils.handle_output( + tag_list, + pagination_info=tags.pagination_info, + columns=columns, + output_format=output_format, ) diff --git a/src/zenml/cli/user_management.py b/src/zenml/cli/user_management.py index 8399c38e42c..5bc1db0ac48 100644 --- a/src/zenml/cli/user_management.py +++ b/src/zenml/cli/user_management.py @@ -13,13 +13,16 @@ # permissions and limitations under the License. """Functionality to administer users of the ZenML CLI and server.""" -from typing import Any, Optional +from typing import Any, Dict, Optional +from uuid import UUID import click from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import is_sorted_or_filtered, list_options +from zenml.cli.utils import ( + list_options, +) from zenml.client import Client from zenml.config.global_config import GlobalConfiguration from zenml.console import console @@ -29,7 +32,51 @@ EntityExistsError, IllegalOperationError, ) -from zenml.models import UserFilter +from zenml.models import UserFilter, UserResponse + + +def _generate_user_data( + user: UserResponse, active_user_id: UUID, output_format: str +) -> Dict[str, Any]: + """Generate additional data for user display. + + Args: + user: The user response. + active_user_id: The ID of the active user. + output_format: The output format. + + Returns: + The additional data for the user. + """ + is_active_user = user.id == active_user_id + + # For many users, the name field contains the email and email field is null + display_name = user.name + display_email = user.email or "" + + # If name looks like an email and email is empty, separate them nicely + if "@" in user.name and not user.email: + display_name = user.full_name or user.name.split("@")[0] + display_email = user.name + + result = { + "display_name": display_name, + "display_email": display_email, + "role": "admin" if user.is_admin else "user", + "active": user.active, + } + + # If this is the active user, format the name for visual distinction in table output + if is_active_user and output_format == "table": + result["name"] = ( + f"[green]●[/green] [bold green]{display_name}[/bold green] (you)" + ) + else: + result["name"] = display_name + + result["email"] = display_email + + return result @cli.group(cls=TagGroup, tag=CliCategories.IDENTITY_AND_SECURITY) @@ -77,34 +124,37 @@ def describe_user(user_name_or_id: Optional[str] = None) -> None: @user.command("list") -@list_options(UserFilter) -@click.pass_context -def list_users(ctx: click.Context, /, **kwargs: Any) -> None: +@list_options( + UserFilter, + default_columns=["id", "name", "email", "role", "active", "created"], +) +def list_users(output_format: str, columns: str, **kwargs: Any) -> None: """List all users. Args: - ctx: The click context object + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated list of columns to display. kwargs: Keyword arguments to filter the list of users. """ client = Client() - with console.status("Listing stacks...\n"): + with console.status("Listing users..."): users = client.list_users(**kwargs) - if not users: - cli_utils.declare("No users found for the given filters.") - return - cli_utils.print_pydantic_models( - users, - exclude_columns=[ - "created", - "updated", - "email", - "email_opted_in", - "activation_token", - ], - active_models=[Client().active_user], - show_active=not is_sorted_or_filtered(ctx), + active_user_id = client.active_user.id + user_list = [] + for user in users.items: + user_data = cli_utils.prepare_response_data(user) + user_data.update( + _generate_user_data(user, active_user_id, output_format) ) + user_list.append(user_data) + + cli_utils.handle_output( + user_list, + pagination_info=users.pagination_info, + columns=columns, + output_format=output_format, + ) @user.command( diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 5ab3d6f2bcf..274a4ed554b 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -14,7 +14,6 @@ """Utility functions for the CLI.""" import contextlib -import functools import json import os import platform @@ -39,7 +38,6 @@ Type, TypeVar, Union, - cast, ) import click @@ -57,8 +55,10 @@ from zenml.client import Client from zenml.console import console, zenml_style_defaults from zenml.constants import ( + ENV_ZENML_CLI_COLUMN_WIDTH, FILTERING_DATETIME_FORMAT, IS_DEBUG_ENV, + handle_int_env_var, ) from zenml.enums import GenericFilterOps, ServiceState, StackComponentType from zenml.logger import get_logger @@ -74,6 +74,7 @@ Page, ServiceConnectorRequirements, StrFilter, + UserResponse, UUIDFilter, ) from zenml.models.v2.base.filter import FilterGenerator @@ -81,7 +82,7 @@ from zenml.stack import StackComponent from zenml.stack.flavor import Flavor from zenml.stack.stack_component import StackComponentConfig -from zenml.utils import dict_utils, secret_utils +from zenml.utils import secret_utils from zenml.utils.package_utils import requirement_installed from zenml.utils.time_utils import expires_in from zenml.utils.typing_utils import get_origin, is_union @@ -110,6 +111,8 @@ logger = get_logger(__name__) +AnyResponse = TypeVar("AnyResponse", bound=BaseIdentifiedResponse) # type: ignore[type-arg] + MAX_ARGUMENT_VALUE_SIZE = 10240 @@ -297,8 +300,9 @@ def print_table( value = escape(value) values.append(value) rich_table.add_row(*values) - if len(rich_table.columns) > 1: - rich_table.columns[0].justify = "center" + # Consistent left alignment for all columns + for column in rich_table.columns: + column.justify = "left" console.print(rich_table) @@ -440,7 +444,7 @@ def __dictify(model: T) -> Dict[str, str]: ] print_table([__dictify(model) for model in table_items]) - print_page_info(models) + print_page_info(models.pagination_info) else: table_items = list(models) @@ -1498,55 +1502,6 @@ def replace_emojis(text: str) -> str: return text -def print_stacks_table( - client: "Client", - stacks: Sequence["StackResponse"], - show_active: bool = False, -) -> None: - """Print a prettified list of all stacks supplied to this method. - - Args: - client: Repository instance - stacks: List of stacks - show_active: Flag to decide whether to append the active stack on the - top of the list. - """ - stack_dicts = [] - - stacks = list(stacks) - active_stack = client.active_stack_model - if show_active: - if active_stack.id not in [s.id for s in stacks]: - stacks.append(active_stack) - - stacks = [s for s in stacks if s.id == active_stack.id] + [ - s for s in stacks if s.id != active_stack.id - ] - - active_stack_model_id = client.active_stack_model.id - for stack in stacks: - is_active = stack.id == active_stack_model_id - - if stack.user: - user_name = stack.user.name - else: - user_name = "-" - - stack_config = { - "ACTIVE": ":point_right:" if is_active else "", - "STACK NAME": stack.name, - "STACK ID": stack.id, - "OWNER": user_name, - **{ - component_type.upper(): components[0].name - for component_type, components in stack.components.items() - }, - } - stack_dicts.append(stack_config) - - print_table(stack_dicts) - - def print_components_table( client: "Client", component_type: StackComponentType, @@ -2315,15 +2270,15 @@ def check_zenml_pro_project_availability() -> None: ) -def print_page_info(page: Page[T]) -> None: +def print_page_info(pagination_info: Dict[str, Any]) -> None: """Print all page information showing the number of items and pages. Args: - page: The page to print the information for. + pagination_info: The pagination information to print. """ declare( - f"Page `({page.index}/{page.total_pages})`, `{page.total}` items " - f"found for the applied filters." + f"Page `({pagination_info['index']}/{pagination_info['total_pages']})`, " + f"`{pagination_info['total']}` items found for the applied filters." ) @@ -2440,7 +2395,9 @@ def _is_list_field(field_info: Any) -> bool: ) -def list_options(filter_model: Type[BaseFilter]) -> Callable[[F], F]: +def list_options( + filter_model: Type[BaseFilter], default_columns: Optional[List[str]] = None +) -> Callable[[F], F]: """Create a decorator to generate the correct list of filter parameters. The Outer decorator (`list_options`) is responsible for creating the inner @@ -2452,6 +2409,8 @@ def list_options(filter_model: Type[BaseFilter]) -> Callable[[F], F]: Args: filter_model: The filter model based on which to decorate the function. + default_columns: Optional list of column names to use as defaults when + --columns is not specified and output format is table. Returns: The inner decorator. @@ -2477,6 +2436,33 @@ def inner_decorator(func: F) -> F: create_data_type_help_text(filter_model, k) ) + # Add columns and output options + options.extend( + [ + click.option( + "--columns", + type=str, + default=",".join(default_columns) + if default_columns + else "", + help="Comma-separated list of columns to display.", + ), + click.option( + "--output", + "-o", + "output_format", + type=click.Choice(["table", "json", "yaml", "tsv", "csv"]), + default=get_default_output_format(), + help="Output format for the list.", + ), + ] + ) + + def wrapper(function: F) -> F: + for option in reversed(options): + function = option(function) + return function + func.__doc__ = ( f"{func.__doc__} By default all filters are " f"interpreted as a check for equality. However advanced " @@ -2497,17 +2483,7 @@ def inner_decorator(func: F) -> F: f"{joined_data_type_descriptors}" ) - for option in reversed(options): - func = option(func) - - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - nonlocal func - - kwargs = dict_utils.remove_none_values(kwargs) - return func(*args, **kwargs) - - return cast(F, wrapper) + return wrapper(func) return inner_decorator @@ -2715,3 +2691,396 @@ def requires_mac_env_var_warning() -> bool: "OBJC_DISABLE_INITIALIZE_FORK_SAFETY" ) and mac_version_tuple >= (10, 13) return False + + +def get_default_output_format() -> str: + """Get the default output format from environment variable. + + Returns: + The default output format, falling back to "table" if not configured. + """ + from zenml.constants import ENV_ZENML_DEFAULT_OUTPUT + + return os.environ.get(ENV_ZENML_DEFAULT_OUTPUT, "table") + + +def prepare_response_data(item: AnyResponse) -> Dict[str, Any]: + """Prepare data from BaseResponse instances. + + Args: + item: BaseResponse instance to format + + Returns: + Dictionary with the data + """ + item_data = {"id": item.id} + + if hasattr(item, "name"): + item_data["name"] = getattr(item, "name") + + if item.body is not None: + body_data = item.body.model_dump(mode="json") + item_data.update(body_data) + + if item.resources is not None: + if user := getattr(item.resources, "user", None): + if isinstance(user, UserResponse): + item_data["user"] = user.name + + return item_data + + +def handle_output( + data: List[Dict[str, Any]], + pagination_info: Dict[str, Any], + columns: List[str], + output_format: str, +) -> None: + """Handle output formatting for CLI commands. + + This function processes the output formatting parameters from CLI options + and calls the appropriate rendering function. + + Args: + data: List of dictionaries to render + pagination_info: Info about the pagination + output_format: Optional output format (table, json, yaml, tsv, csv). + columns: Optional comma-separated column names. If None and + default_columns provided, uses default_columns for table output. + """ + cli_output = prepare_output( + data=data, + output_format=output_format, + columns=columns, + pagination=pagination_info, + ) + if cli_output: + from zenml_cli import clean_output + + clean_output(cli_output) + + if pagination_info: + print_page_info(pagination_info) + + +def prepare_output( + data: List[Dict[str, Any]], + output_format: str = "table", + columns: Optional[List[str]] = None, + pagination: Optional[Dict[str, Any]] = None, +) -> Optional[str]: + """Render data in specified format following ZenML CLI table guidelines. + + This function provides a centralized way to render tabular data across + all ZenML CLI commands with consistent formatting and multiple output + formats. + + Args: + data: List of dictionaries to render + output_format: Output format (table, json, yaml, tsv, none) + columns: Optional list of column names to include + sort_by: Column to sort by + reverse: Whether to reverse sort order + no_truncate: Whether to disable truncation + no_color: Whether to disable colored output + max_width: Maximum table width (default: use terminal width) + pagination: Optional pagination metadata for JSON/YAML output + **kwargs: Additional formatting options + + Returns: + The rendered table in the specified format or None if no data is provided + + Raises: + ValueError: If an unsupported output format is provided + """ + selected_columns = columns.split(",") + filtered_data = [] + for entry in data: + filtered_data.append( + {k: entry[k] for k in selected_columns if k in entry} + ) + + if output_format == "json": + return _render_json(filtered_data, pagination=pagination) + elif output_format == "yaml": + return _render_yaml(filtered_data, pagination=pagination) + elif output_format == "tsv": + return _render_tsv(filtered_data) + elif output_format == "csv": + return _render_csv(filtered_data) + elif output_format == "table": + return _render_table(filtered_data) + else: + raise ValueError(f"Unsupported output format: {output_format}") + + +def _render_json( + data: List[Dict[str, Any]], + pagination: Optional[Dict[str, Any]] = None, +) -> str: + """Render data as JSON. + + Args: + data: List of data dictionaries to render + pagination: Optional pagination metadata + + Returns: + JSON string representation of the data + """ + output = {"items": data} + + if pagination: + output["pagination"] = pagination + + return json.dumps(output, indent=2, default=str) + + +def _render_yaml( + data: List[Dict[str, Any]], + pagination: Optional[Dict[str, Any]] = None, +) -> str: + """Render data as YAML. + + Args: + data: List of data dictionaries to render + pagination: Optional pagination metadata + + Returns: + YAML string representation of the data + """ + output = {"items": data} + + if pagination: + output["pagination"] = pagination + + return yaml.dump(output, default_flow_style=False) + + +def _render_tsv( + data: List[Dict[str, Any]], +) -> str: + """Render data as TSV (Tab-Separated Values). + + Args: + data: List of data dictionaries to render + + Returns: + TSV string representation of the data + """ + if not data: + return "" + + headers = list(data[0].keys()) + + lines = [] + lines.append("\t".join(headers)) + + for row in data: + values = [] + for header in headers: + value = str(row.get(header, "")) + value = ( + value.replace("\t", " ").replace("\n", " ").replace("\r", " ") + ) + values.append(value) + lines.append("\t".join(values)) + + return "\n".join(lines) + + +def _render_csv( + data: List[Dict[str, Any]], +) -> str: + """Render data as CSV (Comma-Separated Values). + + Args: + data: List of data dictionaries to render + + Returns: + CSV string representation of the data + """ + headers = list(data[0].keys()) + + lines = [] + + lines.append(",".join(headers)) + + for row in data: + values = [] + for header in headers: + value = ( + str(row.get(header, "")) if row.get(header) is not None else "" + ) + if "," in value or '"' in value: + escaped_value = value.replace('"', '""') + value = f'"{escaped_value}"' + values.append(value) + lines.append(",".join(values)) + + return "\n".join(lines) + + +def _get_terminal_width() -> Optional[int]: + """Get terminal width from ZENML_CLI_COLUMN_WIDTH environment variable or shutil. + + Checks the ZENML_CLI_COLUMN_WIDTH environment variable first, then falls back + to shutil.get_terminal_size() for automatic detection. + + Returns: + Terminal width in characters, or None if cannot be determined + """ + # Check ZenML-specific CLI column width environment variable first + # Use handle_int_env_var with default=0 to indicate "not set" + columns_env = handle_int_env_var(ENV_ZENML_CLI_COLUMN_WIDTH, default=0) + if columns_env > 0: + return columns_env + + # Fall back to shutil.get_terminal_size + try: + size = shutil.get_terminal_size() + # Use a reasonable minimum width even if terminal reports smaller + return max(size.columns, 100) + except (AttributeError, OSError): + # Default to a reasonable width if we can't detect terminal size + return 120 + + +def _render_table( + data: List[Dict[str, Any]], +) -> str: + """Render data as a formatted table following ZenML guidelines. + + Args: + data: List of data dictionaries to render + + Returns: + Formatted table string representation of the data + """ + headers = list(data[0].keys()) + + # Get terminal width using robust detection + terminal_width = _get_terminal_width() + console_width = ( + max(80, min(terminal_width, 200)) if terminal_width else 150 + ) + + column_widths = {} + for header in headers: + content_lengths = [len(str(row.get(header, ""))) for row in data] + header_length = len(header) + optimal_width = max( + header_length, max(content_lengths) if content_lengths else 0 + ) + column_widths[header] = min(50, max(8, optimal_width + 2)) + + available_width = console_width - (len(headers) * 3) + total_content_width = sum(column_widths.values()) + + if total_content_width > available_width: + scale_factor = available_width / total_content_width + for header in headers: + column_widths[header] = max( + 6, int(column_widths[header] * scale_factor) + ) + + rich_table = Table( + box=box.SIMPLE_HEAD, + show_header=True, + show_lines=False, + pad_edge=False, + collapse_padding=False, + expand=True, + width=console_width, + ) + + for header in headers: + # Clean header name: replace underscores with spaces and uppercase + header_display = header.replace("_", " ").upper() + + # Smart overflow strategy based on column type + if "id" in header.lower(): + overflow = "fold" + no_wrap = True + elif "description" in header.lower(): + overflow = "fold" + no_wrap = False + else: + overflow = "ellipsis" + no_wrap = True # Keep single line, truncate with ... if needed + + rich_table.add_column( + header_display, + justify="left", + overflow=overflow, + no_wrap=no_wrap, + min_width=6, + max_width=column_widths[header], + ) + + for row in data: + values = [] + for header in headers: + value = str(row.get(header, "")) + + if not os.getenv("NO_COLOR"): + value = _colorize_value(header, value) + + values.append(value) + + rich_table.add_row(*values) + + from io import StringIO + + output_buffer = StringIO() + + table_console = Console( + width=console_width, + force_terminal=not os.getenv("NO_COLOR"), + no_color=os.getenv("NO_COLOR") is not None, + file=output_buffer, + ) + table_console.print(rich_table) + + return output_buffer.getvalue() + + +def _colorize_value(column: str, value: str) -> str: + """Apply colorization to values based on column type and content. + + Args: + column: Column name to determine colorization rules + value: Value to potentially colorize + + Returns: + Potentially colorized value with Rich markup + """ + # Status-like columns get color coding + if any( + keyword in column.lower() for keyword in ["status", "state", "health"] + ): + value_lower = value.lower() + if value_lower in [ + "active", + "healthy", + "succeeded", + "completed", + ]: + return f"[green]{value}[/green]" + elif value_lower in [ + "running", + "pending", + "initializing", + "starting", + "warning", + ]: + return f"[yellow]{value}[/yellow]" + elif value_lower in [ + "failed", + "error", + "unhealthy", + "stopped", + "crashed", + ]: + return f"[red]{value}[/red]" + + return value diff --git a/src/zenml/client.py b/src/zenml/client.py index fb1301e416f..3005fc3c265 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -7582,6 +7582,7 @@ def list_service_accounts( size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, id: Optional[Union[UUID, str]] = None, + external_user_id: Optional[Union[UUID, str]] = None, created: Optional[Union[datetime, str]] = None, updated: Optional[Union[datetime, str]] = None, name: Optional[str] = None, @@ -7597,6 +7598,7 @@ def list_service_accounts( size: The maximum size of all pages logical_operator: Which logical operator to use [and, or] id: Use the id of stacks to filter by. + external_user_id: Use the external user id for filtering. created: Use to filter by time of creation updated: Use the last updated date for filtering name: Use the service account name for filtering @@ -7615,6 +7617,7 @@ def list_service_accounts( size=size, logical_operator=logical_operator, id=id, + external_user_id=external_user_id, created=created, updated=updated, name=name, diff --git a/src/zenml/constants.py b/src/zenml/constants.py index af0007763e6..8746ac14bd3 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -188,6 +188,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ENV_ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES = ( "ZENML_CODE_REPOSITORY_IGNORE_UNTRACKED_FILES" ) +ENV_ZENML_DEFAULT_OUTPUT = "ZENML_DEFAULT_OUTPUT" +ENV_ZENML_CLI_COLUMN_WIDTH = "ZENML_CLI_COLUMN_WIDTH" + # Environment variable that indicates whether the current environment is running # a step operator. ENV_ZENML_STEP_OPERATOR = "ZENML_STEP_OPERATOR" diff --git a/src/zenml/models/v2/base/page.py b/src/zenml/models/v2/base/page.py index 35f136d6fec..54072cbfa49 100644 --- a/src/zenml/models/v2/base/page.py +++ b/src/zenml/models/v2/base/page.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Page model definitions.""" -from typing import Generator, Generic, List, TypeVar +from typing import Any, Dict, Generator, Generic, List, TypeVar from pydantic import BaseModel from pydantic.types import NonNegativeInt, PositiveInt @@ -43,6 +43,20 @@ def size(self) -> int: """ return len(self.items) + @property + def pagination_info(self) -> Dict[str, Any]: + """Return the pagination info. + + Returns: + The pagination info. + """ + return { + "index": self.index, + "max_size": self.max_size, + "total_pages": self.total_pages, + "total": self.total, + } + def __len__(self) -> int: """Return the item count of the page. diff --git a/src/zenml_cli/__init__.py b/src/zenml_cli/__init__.py new file mode 100644 index 00000000000..f9653cadc89 --- /dev/null +++ b/src/zenml_cli/__init__.py @@ -0,0 +1,78 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Core CLI functionality.""" + +import sys +import logging + +from typing import List + +# Global variable to store original stdout for CLI clean output +_original_stdout = sys.stdout + + +def reroute_stdout() -> None: + """Reroute logging to stderr for CLI commands. + + This function redirects sys.stdout to sys.stderr so that all logging + output goes to stderr, while preserving the original stdout for clean + output that can be piped. + """ + modified_handlers: List[logging.StreamHandler] = [] + + # Reroute stdout to stderr + sys.stdout = sys.stderr + + # Handle existing root logger handlers that hold references to original stdout + for handler in logging.root.handlers: + if ( + isinstance(handler, logging.StreamHandler) + and handler.stream is _original_stdout + ): + handler.stream = sys.stderr + modified_handlers.append(handler) + + # Handle ALL existing individual logger handlers that hold references to original stdout + for _, logger in logging.Logger.manager.loggerDict.items(): + if isinstance(logger, logging.Logger): + for handler in logger.handlers: + if ( + isinstance(handler, logging.StreamHandler) + and handler.stream is _original_stdout + ): + handler.setStream(sys.stderr) + modified_handlers.append(handler) + + + +def clean_output(text: str) -> None: + """Output text to stdout for clean piping, bypassing stderr rerouting. + + This function ensures that specific output goes to the original stdout + even when the CLI has rerouted stdout to stderr. This is useful for + outputting data that should be pipeable (like JSON, CSV, YAML) while + keeping logs and status messages in stderr. + + Args: + text: Text to output to stdout. + """ + _original_stdout.write(text) + if not text.endswith("\n"): + _original_stdout.write("\n") + _original_stdout.flush() + +reroute_stdout() + +# Import the cli only after rerouting stdout +from zenml.cli.cli import cli diff --git a/tests/integration/functional/cli/test_cli_tables.py b/tests/integration/functional/cli/test_cli_tables.py new file mode 100644 index 00000000000..8810bb85e3b --- /dev/null +++ b/tests/integration/functional/cli/test_cli_tables.py @@ -0,0 +1,320 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Integration tests for CLI table formatting functionality.""" + +import json +import os +import subprocess +from unittest.mock import patch + +import pytest +import yaml + +from zenml.constants import ENV_ZENML_CLI_COLUMN_WIDTH + + +class TestCLITableIntegration: + """Integration tests for CLI commands using the new table system.""" + + def run_zenml_cli(self, args, **kwargs): + """Run zenml CLI command using subprocess for realistic testing. + + Args: + args: List of command arguments (e.g., ["stack", "list"]) + **kwargs: Additional subprocess.run arguments + + Returns: + subprocess.CompletedProcess: The result of the command + """ + cmd = ["zenml"] + args + + # Set default subprocess arguments + subprocess_kwargs = { + "capture_output": True, + "text": True, + "timeout": 30, # Prevent hanging tests + } + subprocess_kwargs.update(kwargs) + + return subprocess.run(cmd, **subprocess_kwargs) + + def test_stack_list_table_format(self): + """Test stack list command with table format.""" + result = self.run_zenml_cli(["stack", "list"]) + + assert result.returncode == 0 + # Table output goes to stderr due to stdout rerouting + output = result.stdout + assert "NAME" in output # Check uppercase headers + assert "OWNER" in output + assert "COMPONENTS" in output + + # Check for active stack indicator + if "●" in output: + assert "(active)" in output + + def test_stack_list_json_format(self): + """Test stack list command with JSON format.""" + result = self.run_zenml_cli(["stack", "list", "--output", "json"]) + + assert result.returncode == 0 + + # JSON output should go to stdout via clean_output(), fallback to stderr + output = result.stdout if result.stdout.strip() else result.stderr + + # Parse JSON output + try: + data = json.loads(output) + assert "items" in data or isinstance(data, list) + + # If pagination format + if isinstance(data, dict) and "items" in data: + assert "pagination" in data + items = data["items"] + else: + items = data + + # Check data structure + if items: + assert "name" in items[0] + # Should not contain internal fields + assert "__is_active__" not in items[0] + + except json.JSONDecodeError: + pytest.fail("Invalid JSON output from stack list command") + + def test_stack_list_yaml_format(self): + """Test stack list command with YAML format.""" + result = self.run_zenml_cli(["stack", "list", "--output", "yaml"]) + + assert result.returncode == 0 + + # YAML output should go to stdout via clean_output(), fallback to stderr + output = result.stdout if result.stdout.strip() else result.stderr + + # Parse YAML output + try: + data = yaml.safe_load(output) + assert isinstance(data, (list, dict)) + + # If pagination format + if isinstance(data, dict) and "items" in data: + items = data["items"] + else: + items = data + + # Check data structure + if items and isinstance(items, list): + assert "name" in items[0] + + except yaml.YAMLError: + pytest.fail("Invalid YAML output from stack list command") + + def test_stack_list_tsv_format(self): + """Test stack list command with TSV format.""" + result = self.run_zenml_cli(["stack", "list", "--output", "tsv"]) + + assert result.returncode == 0 + + # TSV output should go to stdout via clean_output(), fallback to stderr + output = result.stdout if result.stdout.strip() else result.stderr + lines = output.strip().split("\n") + if lines and lines[0]: + # Check header line contains tab-separated values + assert "\t" in lines[0] + headers = lines[0].split("\t") + assert "name" in headers + + def test_user_list_table_format(self): + """Test user list command with table format.""" + result = self.run_zenml_cli(["user", "list"]) + + assert result.returncode == 0 + # Table output goes to stderr due to stdout rerouting + output = result.stderr + # Should contain uppercase headers and data + assert output.strip() # Should produce some output + + def test_pipeline_list_table_format(self): + """Test pipeline list command with table format.""" + result = self.run_zenml_cli(["pipeline", "list"]) + + assert result.returncode == 0 + # Table output goes to stderr due to stdout rerouting + output = result.stderr + # Check for reasonable output (may be empty if no pipelines) + if "NAME" in output: + assert "TAGS" in output or "DESCRIPTION" in output + + def test_model_list_table_format(self): + """Test model list command with table format.""" + result = self.run_zenml_cli(["model", "list"]) + + assert result.returncode == 0 + # Table output goes to stderr due to stdout rerouting + output = result.stderr + # Should handle empty or populated model lists + assert isinstance(output, str) + + def test_secret_list_table_format(self): + """Test secret list command with table format.""" + result = self.run_zenml_cli(["secret", "list"]) + + assert result.returncode == 0 + # Table output goes to stderr due to stdout rerouting + output = result.stderr + # Should handle empty or populated secret lists + assert isinstance(output, str) + + @patch.dict(os.environ, {"NO_COLOR": "1"}) + def test_no_color_environment(self): + """Test that NO_COLOR environment variable is respected.""" + result = self.run_zenml_cli(["stack", "list"]) + + assert result.returncode == 0 + # Table output goes to stderr due to stdout rerouting + output = result.stderr + # Output should be present but without ANSI escape codes + # This is a basic check - detailed ANSI parsing would be complex + assert output.strip() + + @patch.dict(os.environ, {ENV_ZENML_CLI_COLUMN_WIDTH: "40"}) + def test_narrow_terminal(self): + """Test table formatting with narrow terminal.""" + result = self.run_zenml_cli(["stack", "list"]) + + assert result.returncode == 0 + # Table output goes to stderr due to stdout rerouting + output = result.stdout + assert output.strip() + + # Check that lines don't exceed reasonable width for narrow terminal + lines = output.split("\n") + for line in lines: + # Remove ANSI escape codes for length check + clean_line = self._remove_ansi_codes(line) + # Allow some flexibility for table borders and formatting + assert len(clean_line) <= 120 # Reasonable upper bound + + @patch.dict(os.environ, {ENV_ZENML_CLI_COLUMN_WIDTH: "200"}) + def test_wide_terminal(self): + """Test table formatting with wide terminal.""" + result = self.run_zenml_cli(["stack", "list"]) + + assert result.returncode == 0 + # Table output goes to stderr due to stdout rerouting + output = result.stdout + assert output.strip() + + def test_pagination_in_json_output(self): + """Test that pagination information is included in JSON output.""" + result = self.run_zenml_cli(["stack", "list", "--output", "json"]) + + assert result.returncode == 0 + + # JSON output should go to stdout via clean_output(), fallback to stderr + output = result.stdout if result.stdout.strip() else result.stderr + + try: + print(output) + data = json.loads(output) + # Check if pagination format is used + if isinstance(data, dict) and "pagination" in data: + pagination = data["pagination"] + assert "index" in pagination or "total" in pagination + except json.JSONDecodeError: + pytest.fail("Invalid JSON output") + + def test_error_handling_invalid_output_format(self): + """Test error handling for invalid output format.""" + result = self.run_zenml_cli(["stack", "list", "--output", "invalid"]) + + # Should either fail gracefully or show help + output = result.stderr + result.stdout + assert result.returncode != 0 or "invalid" not in output.lower() + + def test_mixed_data_types_handling(self): + """Test handling of mixed data types in JSON output.""" + result = self.run_zenml_cli(["stack", "list", "--output", "json"]) + + assert result.returncode == 0 + + # JSON output should go to stdout via clean_output(), fallback to stderr + output = result.stdout if result.stdout.strip() else result.stderr + + try: + data = json.loads(output) + # Should be valid JSON with proper data types + assert isinstance(data, (list, dict)) + except json.JSONDecodeError: + pytest.fail("JSON output contains invalid data types") + + def _remove_ansi_codes(self, text: str) -> str: + """Remove ANSI escape codes from text for length measurement.""" + import re + + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + return ansi_escape.sub("", text) + + def test_consistent_header_formatting(self): + """Test that headers are consistently formatted across commands.""" + commands_to_test = [ + ["stack", "list"], + ["user", "list"], + # Add more as needed, but avoid commands that might not work in test env + ] + + for command in commands_to_test: + result = self.run_zenml_cli(command) + + # Table output goes to stderr due to stdout rerouting + output = result.stderr + if result.returncode == 0 and output.strip(): + # Headers should be uppercase and properly formatted + lines = output.split("\n") + header_line = None + + for line in lines: + # Look for a line that looks like headers (contains uppercase letters) + if any( + c.isupper() for c in line + ) and not line.strip().startswith("Page"): + header_line = line + break + + if header_line: + # Should contain uppercase headers + assert any(word.isupper() for word in header_line.split()) + + def test_status_colorization_in_output(self): + """Test that status values are properly colorized when color is enabled.""" + result = self.run_zenml_cli(["stack", "list"]) + + assert result.returncode == 0 + + # Table output goes to stderr due to stdout rerouting + output = result.stderr + + # If there are status-like fields with known values, they should be colorized + # This is a basic check - the exact colorization depends on the data + if any( + status in output.lower() + for status in ["active", "running", "failed", "pending"] + ): + # Should contain ANSI color codes (unless NO_COLOR is set) + if os.getenv("NO_COLOR") != "1": + # Basic check for ANSI codes presence + assert ( + "\x1b[" in output or "[32m" in output or "[31m" in output + ) diff --git a/tests/unit/utils/test_table_utils.py b/tests/unit/utils/test_table_utils.py new file mode 100644 index 00000000000..b6e91039f8b --- /dev/null +++ b/tests/unit/utils/test_table_utils.py @@ -0,0 +1,238 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Unit tests for table utilities.""" + +import json +import os +from unittest.mock import patch + +import pytest +import yaml + +from zenml.utils.table_utils import zenml_table + + +class TestZenmlTable: + """Test cases for the zenml_table function.""" + + @pytest.fixture + def sample_data(self): + """Sample data for testing.""" + return [ + {"name": "test1", "status": "active", "count": 5}, + {"name": "test2", "status": "pending", "count": 3}, + {"name": "test3", "status": "failed", "count": 1}, + ] + + @pytest.fixture + def stack_data(self): + """Sample stack data with active indicator.""" + return [ + { + "name": "default", + "owner": "admin", + "components": 2, + "__is_active__": True, + }, + { + "name": "production", + "owner": "admin", + "components": 5, + "__is_active__": False, + }, + ] + + def test_table_format_basic(self, sample_data): + """Test basic table formatting.""" + output = zenml_table(sample_data, output_format="table") + assert output is not None + + # Check that headers are uppercase + assert "NAME" in output + assert "STATUS" in output + assert "COUNT" in output + + # Check data is present + assert "test1" in output + assert "active" in output + + def test_json_format(self, sample_data): + """Test JSON output format.""" + output = zenml_table(sample_data, output_format="json") + assert output is not None + + # Parse the JSON output + output_data = json.loads(output) + assert len(output_data) == 3 + assert output_data[0]["name"] == "test1" + assert output_data[0]["status"] == "active" + + def test_json_format_with_pagination(self, sample_data): + """Test JSON output with pagination metadata.""" + pagination = {"index": 1, "total": 3, "max_size": 20} + output = zenml_table( + sample_data, output_format="json", pagination=pagination + ) + assert output is not None + + output_data = json.loads(output) + assert "items" in output_data + assert "pagination" in output_data + assert len(output_data["items"]) == 3 + assert output_data["pagination"]["total"] == 3 + + def test_yaml_format(self, sample_data): + """Test YAML output format.""" + output = zenml_table(sample_data, output_format="yaml") + assert output is not None + + # Parse the YAML output + output_data = yaml.safe_load(output) + assert len(output_data) == 3 + assert output_data[0]["name"] == "test1" + + def test_tsv_format(self, sample_data): + """Test TSV output format.""" + output = zenml_table(sample_data, output_format="tsv") + assert output is not None + + lines = output.strip().split("\n") + # Check header line + assert lines[0] == "name\tstatus\tcount" + # Check data line + assert lines[1] == "test1\tactive\t5" + + def test_column_filtering(self, sample_data): + """Test filtering specific columns.""" + output = zenml_table( + sample_data, output_format="json", columns=["name", "status"] + ) + assert output is not None + + output_data = json.loads(output) + for item in output_data: + assert "name" in item + assert "status" in item + assert "count" not in item + + def test_sorting(self, sample_data): + """Test sorting by column.""" + output = zenml_table(sample_data, output_format="json", sort_by="name") + assert output is not None + + output_data = json.loads(output) + names = [item["name"] for item in output_data] + assert names == ["test1", "test2", "test3"] + + def test_reverse_sorting(self, sample_data): + """Test reverse sorting.""" + output = zenml_table( + sample_data, output_format="json", sort_by="name", reverse=True + ) + assert output is not None + + output_data = json.loads(output) + names = [item["name"] for item in output_data] + assert names == ["test3", "test2", "test1"] + + def test_empty_data(self): + """Test handling of empty data.""" + output = zenml_table([], output_format="table") + assert output == "" or output is None + + def test_invalid_output_format(self, sample_data): + """Test invalid output format raises error.""" + with pytest.raises(ValueError, match="Unsupported output format"): + zenml_table(sample_data, output_format="invalid") + + def test_stack_formatting_json_clean(self, stack_data): + """Test that JSON output removes internal fields.""" + output = zenml_table(stack_data, output_format="json") + assert output is not None + + output_data = json.loads(output) + for item in output_data: + assert "__is_active__" not in item + + def test_status_colorization(self): + """Test status value colorization.""" + data = [ + {"status": "running"}, + {"status": "failed"}, + {"status": "pending"}, + ] + output = zenml_table(data, output_format="table", no_color=False) + assert output is not None + + # Colors should be applied (exact formatting depends on Rich implementation) + assert "running" in output + assert "failed" in output + assert "pending" in output + + @patch.dict(os.environ, {"NO_COLOR": "1"}) + def test_no_color_environment(self, sample_data): + """Test NO_COLOR environment variable is respected.""" + output = zenml_table(sample_data, output_format="table") + assert output is not None + + @patch("zenml.utils.table_utils.shutil.get_terminal_size") + def test_terminal_width_detection(self, mock_terminal_size, sample_data): + """Test terminal width detection.""" + mock_terminal_size.return_value.columns = 120 + + output = zenml_table(sample_data, output_format="table", max_width=100) + assert output is not None # Should produce output + mock_terminal_size.assert_called_once() + + def test_tsv_escaping(self): + """Test TSV format properly escapes special characters.""" + data = [{"field": "value\twith\ttabs\nand\nnewlines"}] + output = zenml_table(data, output_format="tsv") + assert output is not None + + lines = output.strip().split("\n") + assert len(lines) >= 2 # Should have header and data lines + + # Check that tabs and newlines are properly escaped in the data + data_line = lines[1] + fields = data_line.split("\t") + if len(fields) > 1: + # Original tabs should be replaced with spaces + assert ( + "value with tabs and newlines" in fields[1] + or "value\twith" not in fields[1] + ) + + def test_none_values_handling(self): + """Test handling of None values in data.""" + data = [{"name": "test", "value": None, "other": "data"}] + output = zenml_table(data, output_format="json") + assert output is not None + + output_data = json.loads(output) + assert output_data[0]["value"] is None + + def test_mixed_data_types(self): + """Test handling of mixed data types.""" + data = [ + {"name": "test", "count": 42, "active": True, "rate": 3.14}, + {"name": "other", "count": 0, "active": False, "rate": 2.71}, + ] + output = zenml_table(data, output_format="json") + assert output is not None + + output_data = json.loads(output) + assert output_data[0]["count"] == 42 + assert output_data[0]["active"] is True + assert output_data[0]["rate"] == 3.14