diff --git a/CLAUDE.md b/CLAUDE.md index a43fffbbb10..1297d6d2838 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -124,6 +124,58 @@ Use filesystem navigation tools to explore the codebase structure as needed. - Add appropriate error handling - Document public APIs thoroughly +### Field Description Standards +When adding or modifying Field descriptions in stack component configs: + +#### Template Structure +``` +{Purpose statement}. {Valid values/format}. {Example(s)}. {Additional context if needed}. +``` + +#### Core Requirements +1. **Purpose**: Clearly state what the field controls or does +2. **Format**: Specify expected value format (URL, path, enum, etc.) +3. **Examples**: Provide at least one concrete example +4. **Constraints**: Include any limitations or requirements + +#### Quality Standards +- Minimum 30 characters +- Use action words (controls, configures, specifies, determines) +- Include concrete examples with realistic values +- Avoid vague language ("thing", "stuff", "value", "setting") +- Don't start with "The" or end with periods +- Be specific about valid formats and constraints + +#### Example Field Descriptions +```python +# Good examples: +instance_type: Optional[str] = Field( + None, + description="AWS EC2 instance type for step execution. Must be a valid " + "SageMaker-supported instance type. Examples: 'ml.t3.medium' (2 vCPU, 4GB RAM), " + "'ml.m5.xlarge' (4 vCPU, 16GB RAM). Defaults to ml.m5.xlarge for training steps" +) + +path: str = Field( + description="Root path for artifact storage. Must be a valid URI supported by the " + "artifact store implementation. Examples: 's3://my-bucket/artifacts', " + "'/local/storage/path', 'gs://bucket-name/zenml-artifacts'. Path must be accessible " + "with configured credentials" +) + +synchronous: bool = Field( + True, + description="Controls whether pipeline execution blocks the client. If True, " + "the client waits until all steps complete. If False, returns immediately and " + "executes asynchronously. 
Useful for long-running production pipelines" +) +``` + +#### Validation +- Run `python scripts/validate_descriptions.py` to check description quality +- All descriptions must pass validation before merging +- Add validation to CI pipeline to prevent regressions + ### When Fixing Bugs - Add regression tests that would have caught the bug - Understand root cause before implementing fix diff --git a/docs/mkdocstrings_helper.py b/docs/mkdocstrings_helper.py index 5560405876a..469bc49444f 100644 --- a/docs/mkdocstrings_helper.py +++ b/docs/mkdocstrings_helper.py @@ -1,8 +1,10 @@ import argparse +import ast import os import subprocess from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Tuple +import re PYDOCSTYLE_CMD = ( "pydocstyle --convention=google --add-ignore=D100,D101,D102," @@ -66,6 +68,147 @@ def generate_title(s: str) -> str: return s +def extract_field_description_from_code(code: str, field_name: str) -> Optional[str]: + """Extract Field description from Python code using AST parsing.""" + try: + tree = ast.parse(code) + + for node in ast.walk(tree): + if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): + if node.target.id == field_name and isinstance(node.value, ast.Call): + # Check if it's a Field call + if (isinstance(node.value.func, ast.Name) and node.value.func.id == "Field") or \ + (isinstance(node.value.func, ast.Attribute) and node.value.func.attr == "Field"): + + # Extract description from Field arguments + for keyword in node.value.keywords: + if keyword.arg == "description" and isinstance(keyword.value, ast.Constant): + return keyword.value.value + except: + pass + return None + + +def generate_docstring_attributes_from_fields(file_path: Path) -> None: + """Generate docstring attributes section from Pydantic Field descriptions.""" + if not file_path.exists() or not file_path.name.endswith('.py'): + return + + try: + content = file_path.read_text(encoding='utf-8') + + # Skip if no Field imports or pydantic usage + if 'from pydantic import' not in content and 'import pydantic' not in content: + return + + # Parse the file to find classes with Field definitions + tree = ast.parse(content) + modified = False + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + # Check if class has Field definitions + field_descriptions = {} + class_start_line = node.lineno + + # Find Field definitions in the class + for item in node.body: + if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name): + field_name = item.target.id + if isinstance(item.value, ast.Call): + # Check if it's a Field call + if (isinstance(item.value.func, ast.Name) and item.value.func.id == "Field") or \ + (isinstance(item.value.func, ast.Attribute) and item.value.func.attr == "Field"): + + # Extract description + for keyword in item.value.keywords: + if keyword.arg == "description" and isinstance(keyword.value, ast.Constant): + field_descriptions[field_name] = keyword.value.value + + # If we found Field descriptions, update the docstring + if field_descriptions: + lines = content.split('\n') + docstring_start, docstring_end = find_class_docstring_range(lines, class_start_line - 1) + + if docstring_start is not None and docstring_end is not None: + # Extract existing docstring + existing_docstring = '\n'.join(lines[docstring_start:docstring_end + 1]) + + # Check if it already has Attributes section + if 'Attributes:' not in existing_docstring: + # Generate attributes section + attributes_section = 
generate_attributes_section(field_descriptions) + + # Insert before the closing triple quotes + if existing_docstring.strip().endswith('"""'): + # Multi-line docstring + new_docstring = existing_docstring.rstrip('"""').rstrip() + '\n\n' + attributes_section + '\n """' + elif existing_docstring.strip().endswith("'''"): + # Multi-line docstring with single quotes + new_docstring = existing_docstring.rstrip("'''").rstrip() + '\n\n' + attributes_section + "\n '''" + else: + continue + + lines[docstring_start:docstring_end + 1] = new_docstring.split('\n') + modified = True + + if modified: + file_path.write_text('\n'.join(lines), encoding='utf-8') + + except Exception as e: + print(f"Warning: Could not process {file_path}: {e}") + + +def find_class_docstring_range(lines: List[str], class_line: int) -> Tuple[Optional[int], Optional[int]]: + """Find the start and end line numbers of a class docstring.""" + # Look for docstring starting after class definition + for i in range(class_line + 1, min(class_line + 10, len(lines))): + line = lines[i].strip() + if line.startswith('"""') or line.startswith("'''"): + quote_type = '"""' if line.startswith('"""') else "'''" + start_line = i + + # Check if it's a single-line docstring + if line.count(quote_type) >= 2: + return start_line, start_line + + # Find the end of multi-line docstring + for j in range(i + 1, len(lines)): + if quote_type in lines[j]: + return start_line, j + return None, None + + +def generate_attributes_section(field_descriptions: dict) -> str: + """Generate an Attributes section from field descriptions.""" + attributes_lines = [" Attributes:"] + + for field_name, description in field_descriptions.items(): + # Clean up description - remove extra whitespace and line breaks + clean_description = ' '.join(description.split()) + attributes_lines.append(f" {field_name}: {clean_description}") + + return '\n'.join(attributes_lines) + + +def process_pydantic_files_in_directory(directory: Path) -> None: + """Process all Python files in a directory to generate docstring attributes.""" + if not directory.exists(): + return + + print(f"Processing Pydantic files in {directory}...") + + # Find all Python files recursively + python_files = list(directory.rglob("*.py")) + + for file_path in python_files: + # Skip __pycache__ directories and other non-source files + if "__pycache__" in str(file_path) or file_path.name.startswith("_"): + continue + + generate_docstring_attributes_from_fields(file_path) + + def create_entity_docs( api_doc_file_dir: Path, ignored_modules: List[str], @@ -164,6 +307,8 @@ def generate_docs( ignored_modules: A list of modules that should be ignored. 
validate: Boolean if pydocstyle should be verified within dir """ + # First, process all Pydantic files to generate docstring attributes + process_pydantic_files_in_directory(path) # Set up output paths for the generated md files api_doc_file_dir = output_path / API_DOCS cli_dev_doc_file_dir = output_path / API_DOCS / "cli" diff --git a/src/zenml/artifact_stores/base_artifact_store.py b/src/zenml/artifact_stores/base_artifact_store.py index 41537958431..3ca322d99d4 100644 --- a/src/zenml/artifact_stores/base_artifact_store.py +++ b/src/zenml/artifact_stores/base_artifact_store.py @@ -33,7 +33,7 @@ cast, ) -from pydantic import model_validator +from pydantic import Field, model_validator from zenml.constants import ( ENV_ZENML_SERVER, @@ -187,9 +187,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: class BaseArtifactStoreConfig(StackComponentConfig): - """Config class for `BaseArtifactStore`.""" + """Config class for `BaseArtifactStore`. - path: str + Base configuration for artifact storage backends. + Field descriptions are defined inline using Field() descriptors. + """ + + path: str = Field( + description="Root path for artifact storage. Must be a valid URI supported by the " + "specific artifact store implementation. Examples: 's3://my-bucket/artifacts', " + "'/local/storage/path', 'gs://bucket-name/zenml-artifacts', 'azure://container/path'. " + "Path must be accessible with the configured credentials and permissions" + ) SUPPORTED_SCHEMES: ClassVar[Set[str]] IS_IMMUTABLE_FILESYSTEM: ClassVar[bool] = False diff --git a/src/zenml/container_registries/base_container_registry.py b/src/zenml/container_registries/base_container_registry.py index e7676b31b7b..abbaad757eb 100644 --- a/src/zenml/container_registries/base_container_registry.py +++ b/src/zenml/container_registries/base_container_registry.py @@ -16,7 +16,7 @@ import re from typing import TYPE_CHECKING, Optional, Tuple, Type, cast -from pydantic import field_validator +from pydantic import Field, field_validator from zenml.constants import DOCKER_REGISTRY_RESOURCE_TYPE from zenml.enums import StackComponentType @@ -36,12 +36,24 @@ class BaseContainerRegistryConfig(AuthenticationConfigMixin): """Base config for a container registry. - Attributes: - uri: The URI of the container registry. + Configuration for connecting to container image registries. + Field descriptions are defined inline using Field() descriptors. """ - uri: str - default_repository: Optional[str] = None + uri: str = Field( + description="Container registry URI (e.g., 'gcr.io' for Google Container " + "Registry, 'docker.io' for Docker Hub, 'registry.gitlab.com' for GitLab " + "Container Registry, 'ghcr.io' for GitHub Container Registry). This is " + "the base URL where container images will be pushed to and pulled from." + ) + default_repository: Optional[str] = Field( + default=None, + description="Default repository namespace for image storage (e.g., " + "'username' for Docker Hub, 'project-id' for GCR, 'organization' for " + "GitHub Container Registry). If not specified, images will be stored at " + "the registry root. 
For Docker Hub this would mean only official images " + "can be pushed.", + ) @field_validator("uri") @classmethod diff --git a/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py b/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py index a3c7cd5ff7c..f7a04930455 100644 --- a/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +++ b/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py @@ -37,100 +37,123 @@ class SagemakerOrchestratorSettings(BaseSettings): - """Settings for the Sagemaker orchestrator. - - Attributes: - synchronous: If `True`, the client running a pipeline using this - orchestrator waits until all steps finish running. If `False`, - the client returns immediately and the pipeline is executed - asynchronously. Defaults to `True`. - instance_type: The instance type to use for the processing job. - execution_role: The IAM role to use for the step execution. - processor_role: DEPRECATED: use `execution_role` instead. - volume_size_in_gb: The size of the EBS volume to use for the processing - job. - max_runtime_in_seconds: The maximum runtime in seconds for the - processing job. - tags: Tags to apply to the Processor/Estimator assigned to the step. - pipeline_tags: Tags to apply to the pipeline via the - sagemaker.workflow.pipeline.Pipeline.create method. - processor_tags: DEPRECATED: use `tags` instead. - keep_alive_period_in_seconds: The time in seconds after which the - provisioned instance will be terminated if not used. This is only - applicable for TrainingStep type and it is not possible to use - TrainingStep type if the `output_data_s3_uri` is set to Dict[str, str]. - use_training_step: Whether to use the TrainingStep type. - It is not possible to use TrainingStep type - if the `output_data_s3_uri` is set to Dict[str, str] or if the - `output_data_s3_mode` != "EndOfJob". - processor_args: Arguments that are directly passed to the SageMaker - Processor for a specific step, allowing for overriding the default - settings provided when configuring the component. See - https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor - for a full list of arguments. - For processor_args.instance_type, check - https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html - for a list of available instance types. - environment: Environment variables to pass to the container. - estimator_args: Arguments that are directly passed to the SageMaker - Estimator for a specific step, allowing for overriding the default - settings provided when configuring the component. See - https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator - for a full list of arguments. - For a list of available instance types, check - https://docs.aws.amazon.com/sagemaker/latest/dg/cmn-info-instance-types.html. - input_data_s3_mode: How data is made available to the container. - Two possible input modes: File, Pipe. - input_data_s3_uri: S3 URI where data is located if not locally, - e.g. s3://my-bucket/my-data/train. How data will be made available - to the container is configured with input_data_s3_mode. Two possible - input types: - - str: S3 location where training data is saved. - - Dict[str, str]: (ChannelName, S3Location) which represent - - Dict[str, str]: (ChannelName, S3Location) which represent - channels (e.g. training, validation, testing) where - specific parts of the data are saved in S3. 
- output_data_s3_mode: How data is uploaded to the S3 bucket. - Two possible output modes: EndOfJob, Continuous. - output_data_s3_uri: S3 URI where data is uploaded after or during processing run. - e.g. s3://my-bucket/my-data/output. How data will be made available - to the container is configured with output_data_s3_mode. Two possible - input types: - - str: S3 location where data will be uploaded from a local folder - named /opt/ml/processing/output/data. - - Dict[str, str]: (ChannelName, S3Location) which represent - channels (e.g. output_one, output_two) where - specific parts of the data are stored locally for S3 upload. - Data must be available locally in /opt/ml/processing/output/data/. - """ - - synchronous: bool = True + """Settings for the Sagemaker orchestrator.""" + + synchronous: bool = Field( + True, + description="Controls whether pipeline execution blocks the client. If True, " + "the client waits until all steps complete before returning. If False, " + "returns immediately and executes asynchronously. Useful for long-running " + "production pipelines where you don't want to maintain a connection", + ) - instance_type: Optional[str] = None - execution_role: Optional[str] = None - volume_size_in_gb: int = 30 - max_runtime_in_seconds: int = 86400 - tags: Dict[str, str] = {} - pipeline_tags: Dict[str, str] = {} - keep_alive_period_in_seconds: Optional[int] = 300 # 5 minutes - use_training_step: Optional[bool] = None + instance_type: Optional[str] = Field( + None, + description="AWS EC2 instance type for step execution. Must be a valid " + "SageMaker-supported instance type. Examples: 'ml.t3.medium' (2 vCPU, 4GB RAM), " + "'ml.m5.xlarge' (4 vCPU, 16GB RAM), 'ml.p3.2xlarge' (8 vCPU, 61GB RAM, 1 GPU). " + "Defaults to ml.m5.xlarge for training steps or ml.t3.medium for processing steps", + ) + execution_role: Optional[str] = Field( + None, + description="IAM role ARN for SageMaker step execution permissions. Must have " + "necessary policies attached (SageMakerFullAccess, S3 access, etc.). " + "Example: 'arn:aws:iam::123456789012:role/SageMakerExecutionRole'. " + "If not provided, uses the default SageMaker execution role", + ) + volume_size_in_gb: int = Field( + 30, + description="EBS volume size in GB for step execution storage. Must be between " + "1-16384 GB. Used for temporary files, model artifacts, and data processing. " + "Larger volumes needed for big datasets or model training. Example: 30 for " + "small jobs, 100+ for large ML training jobs", + ) + max_runtime_in_seconds: int = Field( + 86400, # 24 hours + description="Maximum execution time in seconds before job termination. Must be " + "between 1-432000 seconds (5 days). Used to prevent runaway jobs and control costs. " + "Examples: 3600 (1 hour), 86400 (24 hours), 259200 (3 days). " + "Consider your longest expected step duration", + ) + tags: Dict[str, str] = Field( + default_factory=dict, + description="Tags to apply to the Processor/Estimator assigned to the step. " + "Example: {'Environment': 'Production', 'Project': 'MLOps'}", + ) + pipeline_tags: Dict[str, str] = Field( + default_factory=dict, + description="Tags to apply to the pipeline via the " + "sagemaker.workflow.pipeline.Pipeline.create method. Example: " + "{'Environment': 'Production', 'Project': 'MLOps'}", + ) + keep_alive_period_in_seconds: Optional[int] = Field( + 300, # 5 minutes + description="The time in seconds after which the provisioned instance " + "will be terminated if not used. 
This is only applicable for " + "TrainingStep type.", + ) + use_training_step: Optional[bool] = Field( + None, + description="Whether to use the TrainingStep type. It is not possible " + "to use TrainingStep type if the `output_data_s3_uri` is set to " + "Dict[str, str] or if the `output_data_s3_mode` != 'EndOfJob'.", + ) - processor_args: Dict[str, Any] = {} - estimator_args: Dict[str, Any] = {} - environment: Dict[str, str] = {} + processor_args: Dict[str, Any] = Field( + default_factory=dict, + description="Arguments that are directly passed to the SageMaker " + "Processor for a specific step, allowing for overriding the default " + "settings provided when configuring the component. Example: " + "{'instance_count': 2, 'base_job_name': 'my-processing-job'}", + ) + estimator_args: Dict[str, Any] = Field( + default_factory=dict, + description="Arguments that are directly passed to the SageMaker " + "Estimator for a specific step, allowing for overriding the default " + "settings provided when configuring the component. Example: " + "{'train_instance_count': 2, 'train_max_run': 3600}", + ) + environment: Dict[str, str] = Field( + default_factory=dict, + description="Environment variables to pass to the container. " + "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", + ) - input_data_s3_mode: str = "File" + input_data_s3_mode: str = Field( + "File", + description="How data is made available to the container. " + "Two possible input modes: File, Pipe.", + ) input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field( - default=None, union_mode="left_to_right" + default=None, + union_mode="left_to_right", + description="S3 URI where data is located if not locally. Example string: " + "'s3://my-bucket/my-data/train'. Example dict: " + "{'training': 's3://bucket/train', 'validation': 's3://bucket/val'}", ) - output_data_s3_mode: str = DEFAULT_OUTPUT_DATA_S3_MODE + output_data_s3_mode: str = Field( + DEFAULT_OUTPUT_DATA_S3_MODE, + description="How data is uploaded to the S3 bucket. " + "Two possible output modes: EndOfJob, Continuous.", + ) output_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field( - default=None, union_mode="left_to_right" + default=None, + union_mode="left_to_right", + description="S3 URI where data is uploaded after or during processing run. " + "Example string: 's3://my-bucket/my-data/output'. Example dict: " + "{'output_one': 's3://bucket/out1', 'output_two': 's3://bucket/out2'}", + ) + processor_role: Optional[str] = Field( + None, + description="DEPRECATED: use `execution_role` instead. " + "The IAM role to use for the step execution.", + ) + processor_tags: Optional[Dict[str, str]] = Field( + None, + description="DEPRECATED: use `tags` instead. " + "Tags to apply to the Processor assigned to the step.", ) - - processor_role: Optional[str] = None - processor_tags: Optional[Dict[str, str]] = None _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( ("processor_role", "execution_role"), ("processor_tags", "tags") ) @@ -184,39 +207,49 @@ class SagemakerOrchestratorConfig( `aws_secret_access_key`, and optional `aws_auth_role_arn`, - If none of the above are provided, unspecified credentials will be loaded from the default AWS config. - - Attributes: - execution_role: The IAM role ARN to use for the pipeline. 
- scheduler_role: The ARN of the IAM role that will be assumed by - the EventBridge service to launch Sagemaker pipelines - (For more details regarding the required permissions, please check: - https://docs.zenml.io/stacks/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules) - aws_access_key_id: The AWS access key ID to use to authenticate to AWS. - If not provided, the value from the default AWS config will be used. - aws_secret_access_key: The AWS secret access key to use to authenticate - to AWS. If not provided, the value from the default AWS config will - be used. - aws_profile: The AWS profile to use for authentication if not using - service connectors or explicit credentials. If not provided, the - default profile will be used. - aws_auth_role_arn: The ARN of an intermediate IAM role to assume when - authenticating to AWS. - region: The AWS region where the processing job will be run. If not - provided, the value from the default AWS config will be used. - bucket: Name of the S3 bucket to use for storing artifacts - from the job run. If not provided, a default bucket will be created - based on the following format: - "sagemaker-{region}-{aws-account-id}". """ - execution_role: str - scheduler_role: Optional[str] = None - aws_access_key_id: Optional[str] = SecretField(default=None) - aws_secret_access_key: Optional[str] = SecretField(default=None) - aws_profile: Optional[str] = None - aws_auth_role_arn: Optional[str] = None - region: Optional[str] = None - bucket: Optional[str] = None + execution_role: str = Field( + ..., description="The IAM role ARN to use for the pipeline." + ) + scheduler_role: Optional[str] = Field( + None, + description="The ARN of the IAM role that will be assumed by " + "the EventBridge service to launch Sagemaker pipelines. " + "Required for scheduled pipelines.", + ) + aws_access_key_id: Optional[str] = SecretField( + default=None, + description="The AWS access key ID to use to authenticate to AWS. " + "If not provided, the value from the default AWS config will be used.", + ) + aws_secret_access_key: Optional[str] = SecretField( + default=None, + description="The AWS secret access key to use to authenticate to AWS. " + "If not provided, the value from the default AWS config will be used.", + ) + aws_profile: Optional[str] = Field( + None, + description="The AWS profile to use for authentication if not using " + "service connectors or explicit credentials. If not provided, the " + "default profile will be used.", + ) + aws_auth_role_arn: Optional[str] = Field( + None, + description="The ARN of an intermediate IAM role to assume when " + "authenticating to AWS.", + ) + region: Optional[str] = Field( + None, + description="The AWS region where the processing job will be run. " + "If not provided, the value from the default AWS config will be used.", + ) + bucket: Optional[str] = Field( + None, + description="Name of the S3 bucket to use for storing artifacts " + "from the job run. 
If not provided, a default bucket will be created " + "based on the following format: 'sagemaker-{region}-{aws-account-id}'.", + ) @property def is_remote(self) -> bool: diff --git a/src/zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py index 51cc22377a2..8c43737b67e 100644 --- a/src/zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py @@ -34,38 +34,37 @@ class SagemakerStepOperatorSettings(BaseSettings): - """Settings for the Sagemaker step operator. - - Attributes: - experiment_name: The name for the experiment to which the job - will be associated. If not provided, the job runs would be - independent. - input_data_s3_uri: S3 URI where training data is located if not locally, - e.g. s3://my-bucket/my-data/train. How data will be made available - to the container is configured with estimator_args.input_mode. Two possible - input types: - - str: S3 location where training data is saved. - - Dict[str, str]: (ChannelName, S3Location) which represent - channels (e.g. training, validation, testing) where - specific parts of the data are saved in S3. - estimator_args: Arguments that are directly passed to the SageMaker - Estimator. See - https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator - for a full list of arguments. - For estimator_args.instance_type, check - https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html - for a list of available instance types. - environment: Environment variables to pass to the container. - - """ - - instance_type: Optional[str] = None - experiment_name: Optional[str] = None + """Settings for the Sagemaker step operator.""" + + instance_type: Optional[str] = Field( + None, + description="DEPRECATED: The instance type to use for the step execution. " + "Use estimator_args instead. Example: 'ml.m5.xlarge'", + ) + experiment_name: Optional[str] = Field( + None, + description="The name for the experiment to which the job will be associated. " + "If not provided, the job runs would be independent. Example: 'my-training-experiment'", + ) input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field( - default=None, union_mode="left_to_right" + default=None, + union_mode="left_to_right", + description="S3 URI where training data is located if not locally. " + "Example string: 's3://my-bucket/my-data/train'. Example dict: " + "{'training': 's3://bucket/train', 'validation': 's3://bucket/val'}", + ) + estimator_args: Dict[str, Any] = Field( + default_factory=dict, + description="Arguments that are directly passed to the SageMaker Estimator. " + "See SageMaker documentation for available arguments and instance types. Example: " + "{'instance_type': 'ml.m5.xlarge', 'instance_count': 1, " + "'train_max_run': 3600, 'input_mode': 'File'}", + ) + environment: Dict[str, str] = Field( + default_factory=dict, + description="Environment variables to pass to the container during execution. " + "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) - estimator_args: Dict[str, Any] = {} - environment: Dict[str, str] = {} _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( "instance_type" @@ -75,18 +74,20 @@ class SagemakerStepOperatorSettings(BaseSettings): class SagemakerStepOperatorConfig( BaseStepOperatorConfig, SagemakerStepOperatorSettings ): - """Config for the Sagemaker step operator. 
- - Attributes: - role: The role that has to be assigned to the jobs which are - running in Sagemaker. - bucket: Name of the S3 bucket to use for storing artifacts - from the job run. If not provided, a default bucket will be created - based on the following format: "sagemaker-{region}-{aws-account-id}". - """ - - role: str - bucket: Optional[str] = None + """Config for the Sagemaker step operator.""" + + role: str = Field( + ..., + description="The IAM role ARN that has to be assigned to the jobs " + "running in SageMaker. This role must have the necessary permissions " + "to access SageMaker and S3 resources.", + ) + bucket: Optional[str] = Field( + None, + description="Name of the S3 bucket to use for storing artifacts from the job run. " + "If not provided, a default bucket will be created based on the format: " + "'sagemaker-{region}-{aws-account-id}'.", + ) @property def is_remote(self) -> bool: diff --git a/src/zenml/integrations/bentoml/flavors/bentoml_model_deployer_flavor.py b/src/zenml/integrations/bentoml/flavors/bentoml_model_deployer_flavor.py index 67daa7c4164..e99ada72d3f 100644 --- a/src/zenml/integrations/bentoml/flavors/bentoml_model_deployer_flavor.py +++ b/src/zenml/integrations/bentoml/flavors/bentoml_model_deployer_flavor.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Optional, Type +from pydantic import Field + from zenml.integrations.bentoml import BENTOML_MODEL_DEPLOYER_FLAVOR from zenml.model_deployers.base_model_deployer import ( BaseModelDeployerConfig, @@ -28,7 +30,11 @@ class BentoMLModelDeployerConfig(BaseModelDeployerConfig): """Configuration for the BentoMLModelDeployer.""" - service_path: str = "" + service_path: str = Field( + "", + description="Path to the BentoML service directory. " + "If not provided, a default service path will be used.", + ) class BentoMLModelDeployerFlavor(BaseModelDeployerFlavor): diff --git a/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py b/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py index e8f8f3fe8ef..af9540a88e5 100644 --- a/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py +++ b/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type +from pydantic import Field + from zenml.config.base_settings import BaseSettings from zenml.integrations.databricks import DATABRICKS_ORCHESTRATOR_FLAVOR from zenml.logger import get_logger @@ -42,31 +44,64 @@ class DatabricksAvailabilityType(StrEnum): class DatabricksOrchestratorSettings(BaseSettings): """Databricks orchestrator base settings. - Attributes: - spark_version: Spark version. - num_workers: Number of workers. - node_type_id: Node type id. - policy_id: Policy id. - autotermination_minutes: Autotermination minutes. - autoscale: Autoscale. - single_user_name: Single user name. - spark_conf: Spark configuration. - spark_env_vars: Spark environment variables. - schedule_timezone: Schedule timezone. + Configuration for Databricks cluster and Spark execution settings. + Field descriptions are defined inline using Field() descriptors. 
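For context, here is a self-contained sketch (a standalone illustration, not part of this diff) of how inline `Field()` descriptions can be read back programmatically. This is the same metadata that the `CLAUDE.md` validation script and the `mkdocstrings_helper.py` changes above build on, except that the helper extracts it statically via `ast` instead of importing the module. The class and field names below are invented for illustration:

```python
from typing import Optional

from pydantic import BaseModel, Field


class ExampleArtifactStoreConfig(BaseModel):
    """Illustrative config class following the Field description template above."""

    path: str = Field(
        description="Root path for artifact storage. Must be a valid URI supported by "
        "the artifact store implementation. Examples: 's3://my-bucket/artifacts', "
        "'/local/storage/path'"
    )
    default_region: Optional[str] = Field(
        default=None,
        description="Cloud region used when none is configured explicitly. "
        "Example: 'us-east-1'",
    )


# Pydantic v2 exposes the descriptions at runtime, which is what a quality checker
# or docs generator can iterate over:
for name, field_info in ExampleArtifactStoreConfig.model_fields.items():
    print(f"{name}: {field_info.description}")

# The same text also ends up in the JSON schema:
schema = ExampleArtifactStoreConfig.model_json_schema()
print(schema["properties"]["path"]["description"])
```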
""" - # Resources - spark_version: Optional[str] = None - num_workers: Optional[int] = None - node_type_id: Optional[str] = None - policy_id: Optional[str] = None - autotermination_minutes: Optional[int] = None - autoscale: Tuple[int, int] = (0, 1) - single_user_name: Optional[str] = None - spark_conf: Optional[Dict[str, str]] = None - spark_env_vars: Optional[Dict[str, str]] = None - schedule_timezone: Optional[str] = None - availability_type: Optional[DatabricksAvailabilityType] = None + # Cluster Configuration + spark_version: Optional[str] = Field( + default=None, + description="Apache Spark version for the Databricks cluster. " + "Uses workspace default if not specified. Example: '3.2.x-scala2.12'", + ) + num_workers: Optional[int] = Field( + default=None, + description="Fixed number of worker nodes. Cannot be used with autoscaling.", + ) + node_type_id: Optional[str] = Field( + default=None, + description="Databricks node type identifier. " + "Refer to Databricks documentation for available instance types. " + "Example: 'i3.xlarge'", + ) + policy_id: Optional[str] = Field( + default=None, + description="Databricks cluster policy ID for governance and cost control.", + ) + autotermination_minutes: Optional[int] = Field( + default=None, + description="Minutes of inactivity before automatic cluster termination. " + "Helps control costs by shutting down idle clusters.", + ) + autoscale: Tuple[int, int] = Field( + default=(0, 1), + description="Cluster autoscaling bounds as (min_workers, max_workers). " + "Automatically adjusts cluster size based on workload.", + ) + single_user_name: Optional[str] = Field( + default=None, + description="Databricks username for single-user cluster access mode.", + ) + spark_conf: Optional[Dict[str, str]] = Field( + default=None, + description="Custom Spark configuration properties as key-value pairs. " + "Example: {'spark.sql.adaptive.enabled': 'true', 'spark.sql.adaptive.coalescePartitions.enabled': 'true'}", + ) + spark_env_vars: Optional[Dict[str, str]] = Field( + default=None, + description="Environment variables for the Spark driver and executors. " + "Example: {'SPARK_WORKER_MEMORY': '4g', 'SPARK_DRIVER_MEMORY': '2g'}", + ) + schedule_timezone: Optional[str] = Field( + default=None, + description="Timezone for scheduled pipeline execution. " + "Uses IANA timezone format (e.g., 'America/New_York').", + ) + availability_type: Optional[DatabricksAvailabilityType] = Field( + default=None, + description="Instance availability type: ON_DEMAND (guaranteed), SPOT (cost-optimized), " + "or SPOT_WITH_FALLBACK (spot with on-demand backup).", + ) class DatabricksOrchestratorConfig( diff --git a/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py b/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py index 6dd25d8464b..2180d82c8fc 100644 --- a/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py +++ b/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Optional, Type +from pydantic import Field + from zenml.feature_stores.base_feature_store import ( BaseFeatureStoreConfig, BaseFeatureStoreFlavor, @@ -26,11 +28,22 @@ class FeastFeatureStoreConfig(BaseFeatureStoreConfig): - """Config for Feast feature store.""" - - online_host: str = "localhost" - online_port: int = 6379 - feast_repo: str + """Config for Feast feature store. + + Configuration for connecting to Feast feature stores. + Field descriptions are defined inline using Field() descriptors. 
+ """ + + online_host: str = Field( + default="localhost", + description="Online feature store host address (typically Redis server).", + ) + online_port: int = Field( + default=6379, description="Online feature store port number." + ) + feast_repo: str = Field( + description="Local filesystem path to the Feast repository with feature definitions." + ) @property def is_local(self) -> bool: diff --git a/src/zenml/integrations/gcp/flavors/vertex_experiment_tracker_flavor.py b/src/zenml/integrations/gcp/flavors/vertex_experiment_tracker_flavor.py index 49dbf89534f..40d8e0c8aa1 100644 --- a/src/zenml/integrations/gcp/flavors/vertex_experiment_tracker_flavor.py +++ b/src/zenml/integrations/gcp/flavors/vertex_experiment_tracker_flavor.py @@ -16,7 +16,7 @@ import re from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union -from pydantic import field_validator +from pydantic import Field, field_validator from zenml.config.base_settings import BaseSettings from zenml.experiment_trackers.base_experiment_tracker import ( @@ -40,15 +40,15 @@ class VertexExperimentTrackerSettings(BaseSettings): - """Settings for the VertexAI experiment tracker. + """Settings for the VertexAI experiment tracker.""" - Attributes: - experiment: The VertexAI experiment name. - experiment_tensorboard: The VertexAI experiment tensorboard. - """ - - experiment: Optional[str] = None - experiment_tensorboard: Optional[Union[str, bool]] = None + experiment: Optional[str] = Field( + None, description="The VertexAI experiment name." + ) + experiment_tensorboard: Optional[Union[str, bool]] = Field( + None, + description="The VertexAI experiment tensorboard instance to use.", + ) @field_validator("experiment", mode="before") def _validate_experiment(cls, value: str) -> str: @@ -76,39 +76,7 @@ class VertexExperimentTrackerConfig( GoogleCredentialsConfigMixin, VertexExperimentTrackerSettings, ): - """Config for the VertexAI experiment tracker. - - Attributes: - location: Optional. The default location to use when making API calls. If not - set defaults to us-central1. - staging_bucket: Optional. The default staging bucket to use to stage artifacts - when making API calls. In the form gs://... - network: - Optional. The full name of the Compute Engine network to which jobs - and resources should be peered. E.g. "projects/12345/global/networks/myVPC". - Private services access must already be configured for the network. - If specified, all eligible jobs and resources created will be peered - with this VPC. - encryption_spec_key_name: - Optional. The Cloud KMS resource identifier of the customer - managed encryption key used to protect a resource. Has the - form: - ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. - The key needs to be in the same region as where the compute - resource is created. - api_endpoint (str): - Optional. The desired API endpoint, - e.g., us-central1-aiplatform.googleapis.com - api_key (str): - Optional. The API key to use for service calls. - NOTE: Not all services support API keys. - api_transport (str): - Optional. The transport method which is either 'grpc' or 'rest'. - NOTE: "rest" transport functionality is currently in a - beta state (preview). - request_metadata: - Optional. Additional gRPC metadata to send with every client request. 
- """ + """Config for the VertexAI experiment tracker.""" location: Optional[str] = None staging_bucket: Optional[str] = None diff --git a/src/zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py b/src/zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py index 04edcbc0e12..101981c6c04 100644 --- a/src/zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +++ b/src/zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type +from pydantic import Field + from zenml.config.base_settings import BaseSettings from zenml.integrations.gcp import ( GCP_RESOURCE_TYPE, @@ -36,35 +38,37 @@ class VertexOrchestratorSettings(BaseSettings): - """Settings for the Vertex orchestrator. - - Attributes: - synchronous: If `True`, the client running a pipeline using this - orchestrator waits until all steps finish running. If `False`, - the client returns immediately and the pipeline is executed - asynchronously. Defaults to `True`. - labels: Labels to assign to the pipeline job. - node_selector_constraint: Each constraint is a key-value pair label. - For the container to be eligible to run on a node, the node must have - each of the constraints appeared as labels. - For example a GPU type can be providing by one of the following tuples: - - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_A100") - - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_K80") - - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P4") - - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_P100") - - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_T4") - - ("cloud.google.com/gke-accelerator", "NVIDIA_TESLA_V100") - Hint: the selected region (location) must provide the requested accelerator - (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones). - pod_settings: Pod settings to apply. - """ - - labels: Dict[str, str] = {} - synchronous: bool = True - node_selector_constraint: Optional[Tuple[str, str]] = None - pod_settings: Optional[KubernetesPodSettings] = None - - custom_job_parameters: Optional[VertexCustomJobParameters] = None + """Settings for the Vertex orchestrator.""" + + labels: Dict[str, str] = Field( + default_factory=dict, + description="Labels to assign to the pipeline job. " + "Example: {'environment': 'production', 'team': 'ml-ops'}", + ) + synchronous: bool = Field( + True, + description="If `True`, the client running a pipeline using this " + "orchestrator waits until all steps finish running. If `False`, " + "the client returns immediately and the pipeline is executed " + "asynchronously.", + ) + node_selector_constraint: Optional[Tuple[str, str]] = Field( + None, + description="Each constraint is a key-value pair label. For the container " + "to be eligible to run on a node, the node must have each of the " + "constraints appeared as labels. For example, a GPU type can be provided " + "by ('cloud.google.com/gke-accelerator', 'NVIDIA_TESLA_T4')." + "Hint: the selected region (location) must provide the requested accelerator" + "(see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones).", + ) + pod_settings: Optional[KubernetesPodSettings] = Field( + None, + description="Pod settings to apply to the orchestrator and step pods.", + ) + + custom_job_parameters: Optional[VertexCustomJobParameters] = Field( + None, description="Custom parameters for the Vertex AI custom job." 
+ ) _node_selector_deprecation = ( deprecation_utils.deprecate_pydantic_attributes( @@ -78,71 +82,74 @@ class VertexOrchestratorConfig( GoogleCredentialsConfigMixin, VertexOrchestratorSettings, ): - """Configuration for the Vertex orchestrator. - - Attributes: - location: Name of GCP region where the pipeline job will be executed. - Vertex AI Pipelines is available in the following regions: - https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability - pipeline_root: a Cloud Storage URI that will be used by the Vertex AI - Pipelines. If not provided but the artifact store in the stack used - to execute the pipeline is a - `zenml.integrations.gcp.artifact_stores.GCPArtifactStore`, - then a subdirectory of the artifact store will be used. - encryption_spec_key_name: The Cloud KMS resource identifier of the - customer managed encryption key used to protect the job. Has the form: - `projects//locations//keyRings//cryptoKeys/` - . The key needs to be in the same region as where the compute - resource is created. - workload_service_account: the service account for workload run-as - account. Users submitting jobs must have act-as permission on this - run-as account. If not provided, the Compute Engine default service - account for the GCP project in which the pipeline is running is - used. - function_service_account: the service account for cloud function run-as - account, for scheduled pipelines. This service account must have - the act-as permission on the workload_service_account. - If not provided, the Compute Engine default service account for the - GCP project in which the pipeline is running is used. - scheduler_service_account: the service account used by the Google Cloud - Scheduler to trigger and authenticate to the pipeline Cloud Function - on a schedule. If not provided, the Compute Engine default service - account for the GCP project in which the pipeline is running is - used. - network: the full name of the Compute Engine Network to which the job - should be peered. For example, `projects/12345/global/networks/myVPC` - If not provided, the job will not be peered with any network. - private_service_connect: the full name of a Private Service Connect - endpoint to which the job should be peered. For example, - `projects/12345/regions/us-central1/networkAttachments/NETWORK_ATTACHMENT_NAME` - If not provided, the job will not be peered with any private service - connect endpoint. - cpu_limit: The maximum CPU limit for this operator. This string value - can be a number (integer value for number of CPUs) as string, - or a number followed by "m", which means 1/1000. You can specify - at most 96 CPUs. - (see. https://cloud.google.com/vertex-ai/docs/pipelines/machine-types) - memory_limit: The maximum memory limit for this operator. This string - value can be a number, or a number followed by "K" (kilobyte), - "M" (megabyte), or "G" (gigabyte). At most 624GB is supported. - gpu_limit: The GPU limit (positive number) for the operator. - For more information about GPU resources, see: - https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus - """ - - location: str - pipeline_root: Optional[str] = None - encryption_spec_key_name: Optional[str] = None - workload_service_account: Optional[str] = None - network: Optional[str] = None - private_service_connect: Optional[str] = None + """Configuration for the Vertex orchestrator.""" + + location: str = Field( + ..., + description="Name of GCP region where the pipeline job will be executed. 
" + "Vertex AI Pipelines is available in specific regions: " + "https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability", + ) + pipeline_root: Optional[str] = Field( + None, + description="A Cloud Storage URI that will be used by the Vertex AI Pipelines. " + "If not provided but the artifact store in the stack is a GCPArtifactStore, " + "then a subdirectory of the artifact store will be used.", + ) + encryption_spec_key_name: Optional[str] = Field( + None, + description="The Cloud KMS resource identifier of the customer managed " + "encryption key used to protect the job. Has the form: " + "projects//locations//keyRings//cryptoKeys/. " + "The key needs to be in the same region as where the compute resource is created.", + ) + workload_service_account: Optional[str] = Field( + None, + description="The service account for workload run-as account. Users submitting " + "jobs must have act-as permission on this run-as account. If not provided, " + "the Compute Engine default service account for the GCP project is used.", + ) + network: Optional[str] = Field( + None, + description="The full name of the Compute Engine Network to which the job " + "should be peered. For example, 'projects/12345/global/networks/myVPC'. " + "If not provided, the job will not be peered with any network.", + ) + private_service_connect: Optional[str] = Field( + None, + description="The full name of a Private Service Connect endpoint to which " + "the job should be peered. For example, " + "'projects/12345/regions/us-central1/networkAttachments/NETWORK_ATTACHMENT_NAME'. " + "If not provided, the job will not be peered with any private service connect endpoint.", + ) # Deprecated - cpu_limit: Optional[str] = None - memory_limit: Optional[str] = None - gpu_limit: Optional[int] = None - function_service_account: Optional[str] = None - scheduler_service_account: Optional[str] = None + cpu_limit: Optional[str] = Field( + None, + description="DEPRECATED: The maximum CPU limit for this operator. " + "Use custom_job_parameters or pod_settings instead.", + ) + memory_limit: Optional[str] = Field( + None, + description="DEPRECATED: The maximum memory limit for this operator. " + "Use custom_job_parameters or pod_settings instead.", + ) + gpu_limit: Optional[int] = Field( + None, + description="DEPRECATED: The GPU limit for the operator. " + "Use custom_job_parameters or pod_settings instead.", + ) + function_service_account: Optional[str] = Field( + None, + description="DEPRECATED: The service account for cloud function run-as account, " + "for scheduled pipelines. This functionality is no longer supported.", + ) + scheduler_service_account: Optional[str] = Field( + None, + description="DEPRECATED: The service account used by the Google Cloud Scheduler " + "to trigger and authenticate to the pipeline Cloud Function on a schedule. 
" + "This functionality is no longer supported.", + ) _resource_deprecation = deprecation_utils.deprecate_pydantic_attributes( "cpu_limit", diff --git a/src/zenml/integrations/gcp/google_credentials_mixin.py b/src/zenml/integrations/gcp/google_credentials_mixin.py index 3bf6e25de08..681bbf4da66 100644 --- a/src/zenml/integrations/gcp/google_credentials_mixin.py +++ b/src/zenml/integrations/gcp/google_credentials_mixin.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Optional, Tuple, cast +from pydantic import Field + from zenml.logger import get_logger from zenml.stack.stack_component import StackComponent, StackComponentConfig @@ -28,16 +30,19 @@ class GoogleCredentialsConfigMixin(StackComponentConfig): """Config mixin for Google Cloud Platform credentials. - Attributes: - project: GCP project name. If `None`, the project will be inferred from - the environment. - service_account_path: path to the service account credentials file to be - used for authentication. If not provided, the default credentials - will be used. + Provides common GCP authentication configuration for stack components. + Field descriptions are defined inline using Field() descriptors. """ - project: Optional[str] = None - service_account_path: Optional[str] = None + project: Optional[str] = Field( + default=None, + description="Google Cloud Project ID. Auto-detected from environment if not specified.", + ) + service_account_path: Optional[str] = Field( + default=None, + description="Path to service account JSON key file for authentication. " + "Uses Application Default Credentials if not provided.", + ) class GoogleCredentialsMixin(StackComponent): diff --git a/src/zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py b/src/zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py index dc5bb2a1660..be7b19d33cc 100644 --- a/src/zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py +++ b/src/zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Dict, Optional, Type +from pydantic import Field + from zenml.config.base_settings import BaseSettings from zenml.integrations.hyperai import HYPERAI_RESOURCE_TYPE from zenml.logger import get_logger @@ -31,42 +33,38 @@ class HyperAIOrchestratorSettings(BaseSettings): - """HyperAI orchestrator settings. - - Attributes: - mounts_from_to: A dictionary mapping from paths on the HyperAI instance - to paths within the Docker container. This allows users to mount - directories from the HyperAI instance into the Docker container that runs - on it. - """ + """HyperAI orchestrator settings.""" - mounts_from_to: Dict[str, str] = {} + mounts_from_to: Dict[str, str] = Field( + default_factory=dict, + description="A dictionary mapping from paths on the HyperAI instance " + "to paths within the Docker container. This allows users to mount " + "directories from the HyperAI instance into the Docker container.", + ) class HyperAIOrchestratorConfig( BaseOrchestratorConfig, HyperAIOrchestratorSettings ): - """Configuration for the HyperAI orchestrator. - - Attributes: - container_registry_autologin: If True, the orchestrator will attempt to - automatically log in to the container registry specified in the stack - configuration on the HyperAI instance. This is useful if the container - registry requires authentication and the HyperAI instance has not been - manually logged in to the container registry. Defaults to `False`. 
- automatic_cleanup_pipeline_files: If True, the orchestrator will - automatically clean up old pipeline files that are on the HyperAI - instance. Pipeline files will be cleaned up if they are 7 days old or - older. Defaults to `True`. - gpu_enabled_in_container: If True, the orchestrator will enable GPU - support in the Docker container that runs on the HyperAI instance. - Defaults to `True`. - - """ - - container_registry_autologin: bool = False - automatic_cleanup_pipeline_files: bool = True - gpu_enabled_in_container: bool = True + """Configuration for the HyperAI orchestrator.""" + + container_registry_autologin: bool = Field( + False, + description="If True, the orchestrator will attempt to automatically " + "log in to the container registry specified in the stack configuration " + "on the HyperAI instance.", + ) + automatic_cleanup_pipeline_files: bool = Field( + True, + description="If True, the orchestrator will automatically clean up old " + "pipeline files that are on the HyperAI instance. Pipeline files will " + "be cleaned up if they are 7 days old or older.", + ) + gpu_enabled_in_container: bool = Field( + True, + description="If True, the orchestrator will enable GPU support in the " + "Docker container that runs on the HyperAI instance.", + ) @property def is_remote(self) -> bool: diff --git a/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py b/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py index 803e42d1519..b7733bd29a4 100644 --- a/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py +++ b/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type -from pydantic import PositiveInt +from pydantic import Field, PositiveInt from zenml.image_builders import BaseImageBuilderConfig, BaseImageBuilderFlavor from zenml.integrations.kaniko import KANIKO_IMAGE_BUILDER_FLAVOR @@ -39,47 +39,63 @@ class KanikoImageBuilderConfig(BaseImageBuilderConfig): configure secrets and environment variables so that the Kaniko build container is able to push to the container registry (and optionally access the artifact store to upload the build context). - - Attributes: - kubernetes_context: The Kubernetes context in which to run the Kaniko - pod. - kubernetes_namespace: The Kubernetes namespace in which to run the - Kaniko pod. This namespace will not be created and must already - exist. - executor_image: The image of the Kaniko executor to use. - pod_running_timeout: The timeout to wait until the pod is running - in seconds. Defaults to `300`. - env: `env` section of the Kubernetes container spec. - env_from: `envFrom` section of the Kubernetes container spec. - volume_mounts: `volumeMounts` section of the Kubernetes container spec. - volumes: `volumes` section of the Kubernetes pod spec. - service_account_name: Name of the Kubernetes service account to use. - store_context_in_artifact_store: If `True`, the build context will be - stored in the artifact store. If `False`, the build context will be - streamed over stdin of the `kubectl` process that runs the build. - In case the artifact store is used, the container running the build - needs read access to the artifact store. - executor_args: Additional arguments to forward to the Kaniko executor. - See https://github.com/GoogleContainerTools/kaniko#additional-flags - for a full list of available arguments. 
- Example: `["--compressed-caching=false"]` - """ - kubernetes_context: str - kubernetes_namespace: str = "zenml-kaniko" - executor_image: str = DEFAULT_KANIKO_EXECUTOR_IMAGE - pod_running_timeout: PositiveInt = DEFAULT_KANIKO_POD_RUNNING_TIMEOUT - - env: List[Dict[str, Any]] = [] - env_from: List[Dict[str, Any]] = [] - volume_mounts: List[Dict[str, Any]] = [] - volumes: List[Dict[str, Any]] = [] - service_account_name: Optional[str] = None - - store_context_in_artifact_store: bool = False - - executor_args: List[str] = [] + kubernetes_context: str = Field( + ..., + description="The Kubernetes context in which to run the Kaniko pod.", + ) + kubernetes_namespace: str = Field( + "zenml-kaniko", + description="The Kubernetes namespace in which to run the Kaniko pod. " + "This namespace will not be created and must already exist.", + ) + executor_image: str = Field( + DEFAULT_KANIKO_EXECUTOR_IMAGE, + description="The image of the Kaniko executor to use for building container images.", + ) + pod_running_timeout: PositiveInt = Field( + DEFAULT_KANIKO_POD_RUNNING_TIMEOUT, + description="The timeout to wait until the pod is running in seconds.", + ) + + env: List[Dict[str, Any]] = Field( + default_factory=list, + description="Environment variables section of the Kubernetes container spec. " + "Used to configure secrets and environment variables for registry access.", + ) + env_from: List[Dict[str, Any]] = Field( + default_factory=list, + description="EnvFrom section of the Kubernetes container spec. " + "Used to load environment variables from ConfigMaps or Secrets.", + ) + volume_mounts: List[Dict[str, Any]] = Field( + default_factory=list, + description="VolumeMounts section of the Kubernetes container spec. " + "Used to mount volumes containing credentials or other data.", + ) + volumes: List[Dict[str, Any]] = Field( + default_factory=list, + description="Volumes section of the Kubernetes pod spec. " + "Used to define volumes for credentials or other data.", + ) + service_account_name: Optional[str] = Field( + None, + description="Name of the Kubernetes service account to use for the Kaniko pod. " + "This service account should have the necessary permissions for building and pushing images.", + ) + + store_context_in_artifact_store: bool = Field( + False, + description="If `True`, the build context will be stored in the artifact store. " + "If `False`, the build context will be streamed over stdin of the kubectl process.", + ) + + executor_args: List[str] = Field( + default_factory=list, + description="Additional arguments to forward to the Kaniko executor. " + "See Kaniko documentation for available flags, e.g. ['--compressed-caching=false'].", + ) class KanikoImageBuilderFlavor(BaseImageBuilderFlavor): diff --git a/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py b/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py index f0225c79182..eb218a283e5 100644 --- a/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py +++ b/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast -from pydantic import model_validator +from pydantic import Field, model_validator from zenml.config.base_settings import BaseSettings from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE @@ -36,34 +36,43 @@ class KubeflowOrchestratorSettings(BaseSettings): - """Settings for the Kubeflow orchestrator. 
- - Attributes: - synchronous: If `True`, the client running a pipeline using this - orchestrator waits until all steps finish running. If `False`, - the client returns immediately and the pipeline is executed - asynchronously. Defaults to `True`. This setting only - has an effect when specified on the pipeline and will be ignored if - specified on steps. - timeout: How many seconds to wait for synchronous runs. - client_args: Arguments to pass when initializing the KFP client. - client_username: Username to generate a session cookie for the kubeflow client. Both `client_username` - and `client_password` need to be set together. - client_password: Password to generate a session cookie for the kubeflow client. Both `client_username` - and `client_password` need to be set together. - user_namespace: The user namespace to use when creating experiments - and runs. - pod_settings: Pod settings to apply. - """ - - synchronous: bool = True - timeout: int = 1200 - - client_args: Dict[str, Any] = {} - client_username: Optional[str] = SecretField(default=None) - client_password: Optional[str] = SecretField(default=None) - user_namespace: Optional[str] = None - pod_settings: Optional[KubernetesPodSettings] = None + """Settings for the Kubeflow orchestrator.""" + + synchronous: bool = Field( + True, + description="If `True`, the client running a pipeline using this " + "orchestrator waits until all steps finish running. If `False`, " + "the client returns immediately and the pipeline is executed " + "asynchronously.", + ) + timeout: int = Field( + 1200, description="How many seconds to wait for synchronous runs." + ) + + client_args: Dict[str, Any] = Field( + default_factory=dict, + description="Arguments to pass when initializing the KFP client. " + "Example: {'host': 'https://kubeflow.example.com', 'client_id': 'kubeflow-oidc-authservice', 'existing_token': 'your-auth-token'}", + ) + client_username: Optional[str] = SecretField( + default=None, + description="Username to generate a session cookie for the kubeflow client. " + "Both `client_username` and `client_password` need to be set together.", + ) + client_password: Optional[str] = SecretField( + default=None, + description="Password to generate a session cookie for the kubeflow client. " + "Both `client_username` and `client_password` need to be set together.", + ) + user_namespace: Optional[str] = Field( + None, + description="The user namespace to use when creating experiments and runs. " + "Example: 'my-experiments' or 'team-alpha'", + ) + pod_settings: Optional[KubernetesPodSettings] = Field( + None, + description="Pod settings to apply to the orchestrator and step pods.", + ) @model_validator(mode="before") @classmethod @@ -131,25 +140,27 @@ def _validate_and_migrate_pod_settings( class KubeflowOrchestratorConfig( BaseOrchestratorConfig, KubeflowOrchestratorSettings ): - """Configuration for the Kubeflow orchestrator. - - Attributes: - kubeflow_hostname: The hostname to use to talk to the Kubeflow Pipelines - API. If not set, the hostname will be derived from the Kubernetes - API proxy. Mandatory when connecting to a multi-tenant Kubeflow - Pipelines deployment. - kubeflow_namespace: The Kubernetes namespace in which Kubeflow - Pipelines is deployed. Defaults to `kubeflow`. - kubernetes_context: Name of a kubernetes context to run - pipelines in. Not applicable when connecting to a multi-tenant - Kubeflow Pipelines deployment (i.e. when `kubeflow_hostname` is - set) or if the stack component is linked to a Kubernetes service - connector. 
- """ - - kubeflow_hostname: Optional[str] = None - kubeflow_namespace: str = "kubeflow" - kubernetes_context: Optional[str] = None # TODO: Potential setting + """Configuration for the Kubeflow orchestrator.""" + + kubeflow_hostname: Optional[str] = Field( + None, + description="The hostname to use to talk to the Kubeflow Pipelines API. " + "If not set, the hostname will be derived from the Kubernetes API proxy. " + "Mandatory when connecting to a multi-tenant Kubeflow Pipelines deployment. " + "Example: 'https://kubeflow.example.com' or 'kubeflow.company.com'", + ) + kubeflow_namespace: str = Field( + "kubeflow", + description="The Kubernetes namespace in which Kubeflow Pipelines is deployed.", + ) + kubernetes_context: Optional[str] = Field( + None, + description="Name of a kubernetes context to run pipelines in. " + "Not applicable when connecting to a multi-tenant Kubeflow Pipelines " + "deployment (i.e. when `kubeflow_hostname` is set) or if the stack " + "component is linked to a Kubernetes service connector. " + "Example: 'my-cluster' or 'production-cluster'", + ) @model_validator(mode="before") @classmethod diff --git a/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py b/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py index d145b580794..00e4b928e23 100644 --- a/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +++ b/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Type -from pydantic import NonNegativeInt, PositiveInt, field_validator +from pydantic import Field, NonNegativeInt, PositiveInt, field_validator from zenml.config.base_settings import BaseSettings from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE @@ -33,92 +33,122 @@ class KubernetesOrchestratorSettings(BaseSettings): """Settings for the Kubernetes orchestrator. - Attributes: - synchronous: If `True`, the client running a pipeline using this - orchestrator waits until all steps finish running. If `False`, - the client returns immediately and the pipeline is executed - asynchronously. Defaults to `True`. - timeout: How many seconds to wait for synchronous runs. `0` means - to wait for an unlimited duration. - stream_step_logs: If `True`, the orchestrator pod will stream the logs - of the step pods. This only has an effect if specified on the - pipeline, not on individual steps. - service_account_name: Name of the service account to use for the - orchestrator pod. If not provided, a new service account with "edit" - permissions will be created. - step_pod_service_account_name: Name of the service account to use for the - step pods. If not provided, the default service account will be used. - privileged: If the container should be run in privileged mode. - pod_settings: Pod settings to apply to pods executing the steps. - orchestrator_pod_settings: Pod settings to apply to the pod which is - launching the actual steps. - pod_name_prefix: Prefix to use for the pod name. - pod_startup_timeout: The maximum time to wait for a pending step pod to - start (in seconds). - pod_failure_max_retries: The maximum number of times to retry a step - pod if the step Kubernetes pod fails to start - pod_failure_retry_delay: The delay in seconds between pod - failure retries and pod startup retries (in seconds) - pod_failure_backoff: The backoff factor for pod failure retries and - pod startup retries. 
- max_parallelism: Maximum number of steps to run in parallel. - successful_jobs_history_limit: The number of successful jobs - to retain. This only applies to jobs created when scheduling a - pipeline. - failed_jobs_history_limit: The number of failed jobs to retain. - This only applies to jobs created when scheduling a pipeline. - ttl_seconds_after_finished: The amount of seconds to keep finished jobs - before deleting them. **Note**: This does not clean up the - orchestrator pod for non-scheduled runs. - active_deadline_seconds: The active deadline seconds for the job that is - executing the step. - backoff_limit_margin: The value to add to the backoff limit in addition - to the step retries. The retry configuration defined on the step - defines the maximum number of retries that the server will accept - for a step. For this orchestrator, this controls how often the - job running the step will try to start the step pod. There are some - circumstances however where the job will start the pod, but the pod - doesn't actually get to the point of running the step. That means - the server will not receive the maximum amount of retry requests, - which in turn causes other inconsistencies like wrong step statuses. - To mitigate this, this attribute allows to add a margin to the - backoff limit. This means that the job will retry the pod startup - for the configured amount of times plus the margin, which increases - the chance of the server receiving the maximum amount of retry - requests. - pod_failure_policy: The pod failure policy to use for the job that is - executing the step. - prevent_orchestrator_pod_caching: If `True`, the orchestrator pod will - not try to compute cached steps before starting the step pods. - always_build_pipeline_image: If `True`, the orchestrator will always - build the pipeline image, even if all steps have a custom build. - pod_stop_grace_period: When stopping a pipeline run, the amount of - seconds to wait for a step pod to shutdown gracefully. + Configuration options for how pipelines are executed on Kubernetes clusters. + Field descriptions are defined inline using Field() descriptors. """ - synchronous: bool = True - timeout: int = 0 - stream_step_logs: bool = True - service_account_name: Optional[str] = None - step_pod_service_account_name: Optional[str] = None - privileged: bool = False - pod_settings: Optional[KubernetesPodSettings] = None - orchestrator_pod_settings: Optional[KubernetesPodSettings] = None - pod_name_prefix: Optional[str] = None - pod_startup_timeout: int = 60 * 10 # Default 10 minutes - pod_failure_max_retries: int = 3 - pod_failure_retry_delay: int = 10 - pod_failure_backoff: float = 1.0 - max_parallelism: Optional[PositiveInt] = None - successful_jobs_history_limit: Optional[NonNegativeInt] = None - failed_jobs_history_limit: Optional[NonNegativeInt] = None - ttl_seconds_after_finished: Optional[NonNegativeInt] = None - active_deadline_seconds: Optional[NonNegativeInt] = None - backoff_limit_margin: NonNegativeInt = 0 - pod_failure_policy: Optional[Dict[str, Any]] = None - prevent_orchestrator_pod_caching: bool = False - always_build_pipeline_image: bool = False - pod_stop_grace_period: PositiveInt = 30 + synchronous: bool = Field( + default=True, + description="Whether to wait for all pipeline steps to complete. " + "When `False`, the client returns immediately and execution continues asynchronously.", + ) + timeout: int = Field( + default=0, + description="Maximum seconds to wait for synchronous runs. 
Set to `0` for unlimited duration.", + ) + stream_step_logs: bool = Field( + default=True, + description="If `True`, the orchestrator pod will stream the logs " + "of the step pods. This only has an effect if specified on the " + "pipeline, not on individual steps.", + ) + service_account_name: Optional[str] = Field( + default=None, + description="Kubernetes service account for the orchestrator pod. " + "If not specified, creates a new account with 'edit' permissions.", + ) + step_pod_service_account_name: Optional[str] = Field( + default=None, + description="Kubernetes service account for step execution pods. " + "Uses the default service account if not specified.", + ) + privileged: bool = Field( + default=False, + description="Whether to run containers in privileged mode with extended permissions.", + ) + pod_settings: Optional[KubernetesPodSettings] = Field( + default=None, + description="Pod configuration for step execution containers.", + ) + orchestrator_pod_settings: Optional[KubernetesPodSettings] = Field( + default=None, + description="Pod configuration for the orchestrator container that launches step pods.", + ) + pod_name_prefix: Optional[str] = Field( + default=None, + description="Custom prefix for generated pod names. Helps identify pods in the cluster.", + ) + pod_startup_timeout: int = Field( + default=600, + description="Maximum seconds to wait for step pods to start. Default is 10 minutes.", + ) + pod_failure_max_retries: int = Field( + default=3, + description="Maximum retry attempts when step pods fail to start.", + ) + pod_failure_retry_delay: int = Field( + default=10, + description="Delay in seconds between pod failure retry attempts.", + ) + pod_failure_backoff: float = Field( + default=1.0, + description="Exponential backoff factor for retry delays. Values > 1.0 increase delay with each retry.", + ) + max_parallelism: Optional[PositiveInt] = Field( + default=None, + description="Maximum number of step pods to run concurrently. No limit if not specified.", + ) + successful_jobs_history_limit: Optional[NonNegativeInt] = Field( + default=None, + description="Number of successful scheduled jobs to retain in cluster history.", + ) + failed_jobs_history_limit: Optional[NonNegativeInt] = Field( + default=None, + description="Number of failed scheduled jobs to retain in cluster history.", + ) + ttl_seconds_after_finished: Optional[NonNegativeInt] = Field( + default=None, + description="Seconds to keep finished scheduled jobs before automatic cleanup. " + "Note: this does not clean up the orchestrator pod for non-scheduled runs.", + ) + active_deadline_seconds: Optional[NonNegativeInt] = Field( + default=None, + description="Maximum time in seconds that the job executing the step may run before it is terminated.", + ) + backoff_limit_margin: NonNegativeInt = Field( + default=0, + description="The value to add to the backoff limit in addition " + "to the step retries. The retry configuration defined on the step " + "defines the maximum number of retries that the server will accept " + "for a step. For this orchestrator, this controls how often the " + "job running the step will try to start the step pod. There are some " + "circumstances however where the job will start the pod, but the pod " + "doesn't actually get to the point of running the step. That means " + "the server will not receive the maximum number of retry requests, " + "which in turn causes other inconsistencies like wrong step statuses. " + "To mitigate this, this attribute allows adding a margin to the " "backoff limit.
This means that the job will retry the pod startup " "for the configured number of times plus the margin, which increases " "the chance of the server receiving the maximum number of retry " "requests." + ) + pod_failure_policy: Optional[Dict[str, Any]] = Field( + default=None, + description="The pod failure policy to use for the job that is " + "executing the step.", + ) + prevent_orchestrator_pod_caching: bool = Field( + default=False, + description="Whether to disable caching optimization in the orchestrator pod.", + ) + always_build_pipeline_image: bool = Field( + default=False, + description="If `True`, the orchestrator will always build the pipeline image, " + "even if all steps have a custom build.", + ) + pod_stop_grace_period: PositiveInt = Field( + default=30, + description="When stopping a pipeline run, the number of seconds to wait for a step pod to shut down gracefully.", + ) @field_validator("pod_failure_policy", mode="before") @classmethod @@ -144,42 +174,50 @@ def _convert_pod_failure_policy(cls, value: Any) -> Any: class KubernetesOrchestratorConfig( BaseOrchestratorConfig, KubernetesOrchestratorSettings ): - """Configuration for the Kubernetes orchestrator. - - Attributes: - incluster: If `True`, the orchestrator will run the pipeline inside the - same cluster in which it itself is running. This requires the client - to run in a Kubernetes pod itself. If set, the `kubernetes_context` - config option is ignored. If the stack component is linked to a - Kubernetes service connector, this field is ignored. - kubernetes_context: Name of a Kubernetes context to run pipelines in. - If the stack component is linked to a Kubernetes service connector, - this field is ignored. Otherwise, it is mandatory. - kubernetes_namespace: Name of the Kubernetes namespace to be used. - If not provided, `zenml` namespace will be used. - local: If `True`, the orchestrator will assume it is connected to a - local kubernetes cluster and will perform additional validations and - operations to allow using the orchestrator in combination with other - local stack components that store data in the local filesystem - (i.e. it will mount the local stores directory into the pipeline - containers). - skip_local_validations: If `True`, the local validations will be - skipped. - parallel_step_startup_waiting_period: How long to wait in between - starting parallel steps. This can be used to distribute server - load when running pipelines with a huge amount of parallel steps. - pass_zenml_token_as_secret: If `True`, the ZenML token will be passed - as a Kubernetes secret to the pods. For this to work, the Kubernetes - client must have permissions to create secrets in the namespace. - """ - - incluster: bool = False - kubernetes_context: Optional[str] = None - kubernetes_namespace: str = "zenml" - local: bool = False - skip_local_validations: bool = False - parallel_step_startup_waiting_period: Optional[float] = None - pass_zenml_token_as_secret: bool = False + """Configuration for the Kubernetes orchestrator.""" + + incluster: bool = Field( + False, + description="If `True`, the orchestrator will run the pipeline inside the " + "same cluster in which it itself is running. This requires the client " + "to run in a Kubernetes pod itself. If set, the `kubernetes_context` " + "config option is ignored.
If the stack component is linked to a " + "Kubernetes service connector, this field is ignored.", + ) + kubernetes_context: Optional[str] = Field( + None, + description="Name of a Kubernetes context to run pipelines in. " + "If the stack component is linked to a Kubernetes service connector, " + "this field is ignored. Otherwise, it is mandatory.", + ) + kubernetes_namespace: str = Field( + "zenml", + description="Name of the Kubernetes namespace to be used. " + "If not provided, `zenml` namespace will be used.", + ) + local: bool = Field( + False, + description="If `True`, the orchestrator will assume it is connected to a " + "local kubernetes cluster and will perform additional validations and " + "operations to allow using the orchestrator in combination with other " + "local stack components that store data in the local filesystem " + "(i.e. it will mount the local stores directory into the pipeline containers).", + ) + skip_local_validations: bool = Field( + False, description="If `True`, the local validations will be skipped." + ) + parallel_step_startup_waiting_period: Optional[float] = Field( + None, + description="How long to wait in between starting parallel steps. " + "This can be used to distribute server load when running pipelines " + "with a huge amount of parallel steps.", + ) + pass_zenml_token_as_secret: bool = Field( + False, + description="If `True`, the ZenML token will be passed as a Kubernetes secret " + "to the pods. For this to work, the Kubernetes client must have permissions " + "to create secrets in the namespace.", + ) @property def is_remote(self) -> bool: diff --git a/src/zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py b/src/zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py index 7e28ad1d40a..d61f9f30afd 100644 --- a/src/zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py +++ b/src/zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Optional, Type +from pydantic import Field + from zenml.config.base_settings import BaseSettings from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE from zenml.integrations.kubernetes import KUBERNETES_STEP_OPERATOR_FLAVOR @@ -31,27 +33,38 @@ class KubernetesStepOperatorSettings(BaseSettings): """Settings for the Kubernetes step operator. - Attributes: - pod_settings: Pod settings to apply to pods executing the steps. - service_account_name: Name of the service account to use for the pod. - privileged: If the container should be run in privileged mode. - pod_startup_timeout: The maximum time to wait for a pending step pod to - start (in seconds). - pod_failure_max_retries: The maximum number of times to retry a step - pod if the step Kubernetes pod fails to start - pod_failure_retry_delay: The delay in seconds between pod - failure retries and pod startup retries (in seconds) - pod_failure_backoff: The backoff factor for pod failure retries and - pod startup retries. + Configuration options for individual step execution on Kubernetes. + Field descriptions are defined inline using Field() descriptors. 
""" - pod_settings: Optional[KubernetesPodSettings] = None - service_account_name: Optional[str] = None - privileged: bool = False - pod_startup_timeout: int = 60 * 10 # Default 10 minutes - pod_failure_max_retries: int = 3 - pod_failure_retry_delay: int = 10 - pod_failure_backoff: float = 1.0 + pod_settings: Optional[KubernetesPodSettings] = Field( + default=None, + description="Pod configuration for step execution containers.", + ) + service_account_name: Optional[str] = Field( + default=None, + description="Kubernetes service account for step pods. Uses default account if not specified.", + ) + privileged: bool = Field( + default=False, + description="Whether to run step containers in privileged mode with extended permissions.", + ) + pod_startup_timeout: int = Field( + default=600, + description="Maximum seconds to wait for step pods to start. Default is 10 minutes.", + ) + pod_failure_max_retries: int = Field( + default=3, + description="Maximum retry attempts when step pods fail to start.", + ) + pod_failure_retry_delay: int = Field( + default=10, + description="Delay in seconds between pod failure retry attempts.", + ) + pod_failure_backoff: float = Field( + default=1.0, + description="Exponential backoff factor for retry delays. Values > 1.0 increase delay with each retry.", + ) class KubernetesStepOperatorConfig( @@ -59,22 +72,24 @@ class KubernetesStepOperatorConfig( ): """Configuration for the Kubernetes step operator. - Attributes: - kubernetes_namespace: Name of the Kubernetes namespace to be used. - incluster: If `True`, the step operator will run the pipeline inside the - same cluster in which the orchestrator is running. For this to work, - the pod running the orchestrator needs permissions to create new - pods. If set, the `kubernetes_context` config option is ignored. If - the stack component is linked to a Kubernetes service connector, - this field is ignored. - kubernetes_context: Name of a Kubernetes context to run pipelines in. - If the stack component is linked to a Kubernetes service connector, - this field is ignored. Otherwise, it is mandatory. + Defines cluster connection and execution settings. + Field descriptions are defined inline using Field() descriptors. """ - kubernetes_namespace: str = "zenml" - incluster: bool = False - kubernetes_context: Optional[str] = None + kubernetes_namespace: str = Field( + default="zenml", + description="Kubernetes namespace for step execution. Must be a valid namespace name.", + ) + incluster: bool = Field( + default=False, + description="Whether to execute within the same cluster as the orchestrator. " + "Requires appropriate pod creation permissions.", + ) + kubernetes_context: Optional[str] = Field( + default=None, + description="Kubernetes context name for cluster connection. 
" + "Ignored when using service connectors or in-cluster execution.", + ) @property def is_remote(self) -> bool: diff --git a/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py b/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py index bf66791123e..425b5616f45 100644 --- a/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +++ b/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, List, Optional, Type +from pydantic import Field + from zenml.config.base_settings import BaseSettings from zenml.integrations.lightning import LIGHTNING_ORCHESTRATOR_FLAVOR from zenml.logger import get_logger @@ -33,33 +35,47 @@ class LightningOrchestratorSettings(BaseSettings): """Lightning orchestrator base settings. - Attributes: - main_studio_name: Main studio name. - machine_type: Machine type. - user_id: User id. - api_key: api_key. - username: Username. - teamspace: Teamspace. - organization: Organization. - custom_commands: Custom commands to run. - synchronous: If `True`, the client running a pipeline using this - orchestrator waits until all steps finish running. If `False`, - the client returns immediately and the pipeline is executed - asynchronously. Defaults to `True`. This setting only - has an effect when specified on the pipeline and will be ignored if - specified on steps. + Configuration for executing pipelines on Lightning AI platform. + Field descriptions are defined inline using Field() descriptors. """ - # Resources - main_studio_name: Optional[str] = None - machine_type: Optional[str] = None - user_id: Optional[str] = SecretField(default=None) - api_key: Optional[str] = SecretField(default=None) - username: Optional[str] = None - teamspace: Optional[str] = None - organization: Optional[str] = None - custom_commands: Optional[List[str]] = None - synchronous: bool = True + # Lightning AI Platform Configuration + main_studio_name: Optional[str] = Field( + default=None, + description="Lightning AI studio instance name where the pipeline will execute.", + ) + machine_type: Optional[str] = Field( + default=None, + description="Compute instance type for pipeline execution. " + "Refer to Lightning AI documentation for available options.", + ) + user_id: Optional[str] = SecretField( + default=None, description="Lightning AI user ID for authentication." + ) + api_key: Optional[str] = SecretField( + default=None, + description="Lightning AI API key for platform authentication.", + ) + username: Optional[str] = Field( + default=None, description="Lightning AI platform username." + ) + teamspace: Optional[str] = Field( + default=None, + description="Lightning AI teamspace for collaborative pipeline execution.", + ) + organization: Optional[str] = Field( + default=None, + description="Lightning AI organization name for enterprise accounts.", + ) + custom_commands: Optional[List[str]] = Field( + default=None, + description="Additional shell commands to execute in the Lightning AI environment.", + ) + synchronous: bool = Field( + default=True, + description="Whether to wait for pipeline completion. 
" + "When `False`, execution continues asynchronously after submission.", + ) class LightningOrchestratorConfig( diff --git a/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py b/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py index 368820199b3..13160f7fe7e 100644 --- a/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +++ b/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Type -from pydantic import model_validator +from pydantic import Field, model_validator from zenml.config.base_settings import BaseSettings from zenml.experiment_trackers.base_experiment_tracker import ( @@ -59,56 +59,63 @@ def is_databricks_tracking_uri(tracking_uri: str) -> bool: class MLFlowExperimentTrackerSettings(BaseSettings): - """Settings for the MLflow experiment tracker. + """Settings for the MLflow experiment tracker.""" - Attributes: - experiment_name: The MLflow experiment name. - nested: If `True`, will create a nested sub-run for the step. - tags: Tags for the Mlflow run. - """ - - experiment_name: Optional[str] = None - nested: bool = False - tags: Dict[str, Any] = {} + experiment_name: Optional[str] = Field( + None, + description="The MLflow experiment name to use for tracking runs.", + ) + nested: bool = Field( + False, + description="If `True`, will create a nested sub-run for the step.", + ) + tags: Dict[str, Any] = Field( + default_factory=dict, + description="Tags to attach to the MLflow run for categorization and filtering.", + ) class MLFlowExperimentTrackerConfig( BaseExperimentTrackerConfig, MLFlowExperimentTrackerSettings ): - """Config for the MLflow experiment tracker. - - Attributes: - tracking_uri: The uri of the mlflow tracking server. If no uri is set, - your stack must contain a `LocalArtifactStore` and ZenML will - point MLflow to a subdirectory of your artifact store instead. - tracking_username: Username for authenticating with the MLflow - tracking server. When a remote tracking uri is specified, - either `tracking_token` or `tracking_username` and - `tracking_password` must be specified. - tracking_password: Password for authenticating with the MLflow - tracking server. When a remote tracking uri is specified, - either `tracking_token` or `tracking_username` and - `tracking_password` must be specified. - tracking_token: Token for authenticating with the MLflow - tracking server. When a remote tracking uri is specified, - either `tracking_token` or `tracking_username` and - `tracking_password` must be specified. - tracking_insecure_tls: Skips verification of TLS connection to the - MLflow tracking server if set to `True`. - databricks_host: The host of the Databricks workspace with the MLflow - managed server to connect to. This is only required if - `tracking_uri` value is set to `"databricks"`. - enable_unity_catalog: If `True`, will enable the Databricks Unity Catalog for - logging and registering models. - """ + """Config for the MLflow experiment tracker.""" - tracking_uri: Optional[str] = None - tracking_username: Optional[str] = SecretField(default=None) - tracking_password: Optional[str] = SecretField(default=None) - tracking_token: Optional[str] = SecretField(default=None) - tracking_insecure_tls: bool = False - databricks_host: Optional[str] = None - enable_unity_catalog: bool = False + tracking_uri: Optional[str] = Field( + None, + description="The URI of the MLflow tracking server. 
If no URI is set, " + "your stack must contain a LocalArtifactStore and ZenML will point " + "MLflow to a subdirectory of your artifact store instead.", + ) + tracking_username: Optional[str] = SecretField( + default=None, + description="Username for authenticating with the MLflow tracking server. " + "Required when using a remote tracking URI along with tracking_password.", + ) + tracking_password: Optional[str] = SecretField( + default=None, + description="Password for authenticating with the MLflow tracking server. " + "Required when using a remote tracking URI along with tracking_username.", + ) + tracking_token: Optional[str] = SecretField( + default=None, + description="Token for authenticating with the MLflow tracking server. " + "Alternative to username/password authentication for remote tracking URIs.", + ) + tracking_insecure_tls: bool = Field( + False, + description="Skips verification of TLS connection to the MLflow tracking " + "server if set to `True`. Use with caution in production environments.", + ) + databricks_host: Optional[str] = Field( + None, + description="The host of the Databricks workspace with the MLflow managed " + "server to connect to. Required when tracking_uri is set to 'databricks'.", + ) + enable_unity_catalog: bool = Field( + False, + description="If `True`, will enable the Databricks Unity Catalog for " + "logging and registering models.", + ) @model_validator(mode="after") def _ensure_authentication_if_necessary( diff --git a/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py b/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py index 7de40370465..483304e6548 100644 --- a/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py +++ b/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Optional, Type +from pydantic import Field + from zenml.integrations.mlflow import MLFLOW_MODEL_DEPLOYER_FLAVOR from zenml.model_deployers.base_model_deployer import ( BaseModelDeployerConfig, @@ -28,12 +30,15 @@ class MLFlowModelDeployerConfig(BaseModelDeployerConfig): """Configuration for the MLflow model deployer. - Attributes: - service_path: the path where the local MLflow deployment service - configuration, PID and log files are stored. + Configuration for local MLflow model serving. + Field descriptions are defined inline using Field() descriptors. """ - service_path: str = "" + service_path: str = Field( + default="", + description="Local directory for MLflow deployment service files " + "(configuration, PID, and logs). 
Uses temp directory if empty.", + ) @property def is_local(self) -> bool: diff --git a/src/zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py b/src/zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py index ba5fe90422a..dd10c6a2dbe 100644 --- a/src/zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py +++ b/src/zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py @@ -21,6 +21,8 @@ from typing import TYPE_CHECKING, Optional, Set, Type +from pydantic import Field + from zenml.config.base_settings import BaseSettings from zenml.experiment_trackers.base_experiment_tracker import ( BaseExperimentTrackerConfig, @@ -40,24 +42,23 @@ class NeptuneExperimentTrackerConfig(BaseExperimentTrackerConfig): If attributes are left as None, the neptune.init_run() method will try to find the relevant values in the environment - - Attributes: - project: name of the Neptune project you want to log the metadata to - api_token: your Neptune API token """ - project: Optional[str] = None - api_token: Optional[str] = SecretField(default=None) + project: Optional[str] = Field( + None, + description="Name of the Neptune project you want to log the metadata to.", + ) + api_token: Optional[str] = SecretField( + default=None, description="Your Neptune API token for authentication." + ) class NeptuneExperimentTrackerSettings(BaseSettings): - """Settings for the Neptune experiment tracker. + """Settings for the Neptune experiment tracker.""" - Attributes: - tags: Tags for the Neptune run. - """ - - tags: Set[str] = set() + tags: Set[str] = Field( + default_factory=set, description="Tags for the Neptune run." + ) class NeptuneExperimentTrackerFlavor(BaseExperimentTrackerFlavor): diff --git a/src/zenml/integrations/s3/flavors/s3_artifact_store_flavor.py b/src/zenml/integrations/s3/flavors/s3_artifact_store_flavor.py index dabdd9bb384..a956c47fd7a 100644 --- a/src/zenml/integrations/s3/flavors/s3_artifact_store_flavor.py +++ b/src/zenml/integrations/s3/flavors/s3_artifact_store_flavor.py @@ -24,7 +24,7 @@ Type, ) -from pydantic import field_validator +from pydantic import Field, field_validator from zenml.artifact_stores import ( BaseArtifactStoreConfig, @@ -64,12 +64,37 @@ class S3ArtifactStoreConfig( SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"s3://"} - key: Optional[str] = SecretField(default=None) - secret: Optional[str] = SecretField(default=None) - token: Optional[str] = SecretField(default=None) - client_kwargs: Optional[Dict[str, Any]] = None - config_kwargs: Optional[Dict[str, Any]] = None - s3_additional_kwargs: Optional[Dict[str, Any]] = None + key: Optional[str] = SecretField( + default=None, + description="AWS access key ID for authentication. " + "If not provided, credentials will be inferred from the environment.", + ) + secret: Optional[str] = SecretField( + default=None, + description="AWS secret access key for authentication. " + "If not provided, credentials will be inferred from the environment.", + ) + token: Optional[str] = SecretField( + default=None, + description="AWS session token for temporary credentials. " + "If not provided, credentials will be inferred from the environment.", + ) + client_kwargs: Optional[Dict[str, Any]] = Field( + None, + description="Additional keyword arguments to pass to the S3 client. 
" + "For example, to connect to a custom S3-compatible endpoint: " + "{'endpoint_url': 'http://minio:9000'}", + ) + config_kwargs: Optional[Dict[str, Any]] = Field( + None, + description="Additional keyword arguments to pass to the S3 client configuration. " + "For example: {'region_name': 'us-west-2', 'signature_version': 's3v4'}", + ) + s3_additional_kwargs: Optional[Dict[str, Any]] = Field( + None, + description="Additional keyword arguments for S3 operations. " + "For example: {'ACL': 'bucket-owner-full-control'}", + ) _bucket: Optional[str] = None diff --git a/src/zenml/integrations/vllm/flavors/vllm_model_deployer_flavor.py b/src/zenml/integrations/vllm/flavors/vllm_model_deployer_flavor.py index dd8a2bc8a11..c21c76fcff8 100644 --- a/src/zenml/integrations/vllm/flavors/vllm_model_deployer_flavor.py +++ b/src/zenml/integrations/vllm/flavors/vllm_model_deployer_flavor.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Optional, Type +from pydantic import Field + from zenml.integrations.vllm import VLLM_MODEL_DEPLOYER from zenml.model_deployers.base_model_deployer import ( BaseModelDeployerConfig, @@ -28,7 +30,11 @@ class VLLMModelDeployerConfig(BaseModelDeployerConfig): """Configuration for vLLM Inference model deployer.""" - service_path: str = "" + service_path: str = Field( + "", + description="The path where the local vLLM deployment service " + "configuration, PID and log files are stored.", + ) class VLLMModelDeployerFlavor(BaseModelDeployerFlavor): diff --git a/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py b/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py index 96aa8f0e20b..da8b29ff662 100644 --- a/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +++ b/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py @@ -23,7 +23,7 @@ cast, ) -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator from zenml.config.base_settings import BaseSettings from zenml.experiment_trackers.base_experiment_tracker import ( @@ -40,19 +40,24 @@ class WandbExperimentTrackerSettings(BaseSettings): - """Settings for the Wandb experiment tracker. + """Settings for the Wandb experiment tracker.""" - Attributes: - run_name: The Wandb run name. - tags: Tags for the Wandb run. - settings: Settings for the Wandb run. - enable_weave: Whether to enable Weave integration. - """ - - run_name: Optional[str] = None - tags: List[str] = [] - settings: Dict[str, Any] = {} - enable_weave: bool = False + run_name: Optional[str] = Field( + None, + description="The Wandb run name to use for tracking experiments." + ) + tags: List[str] = Field( + default_factory=list, + description="Tags to attach to the Wandb run for categorization and filtering." + ) + settings: Dict[str, Any] = Field( + default_factory=dict, + description="Additional settings for the Wandb run configuration." + ) + enable_weave: bool = Field( + False, + description="Whether to enable Weave integration for enhanced experiment tracking." + ) @field_validator("settings", mode="before") @classmethod @@ -89,18 +94,22 @@ def _convert_settings(cls, value: Any) -> Any: class WandbExperimentTrackerConfig( BaseExperimentTrackerConfig, WandbExperimentTrackerSettings ): - """Config for the Wandb experiment tracker. - - Attributes: - entity: Name of an existing wandb entity. - project_name: Name of an existing wandb project to log to. - api_key: API key to should be authorized to log to the configured wandb - entity and project. 
- """ - - api_key: str = SecretField() - entity: Optional[str] = None - project_name: Optional[str] = None + """Config for the Wandb experiment tracker.""" + + api_key: str = SecretField( + description="API key that should be authorized to log to the configured " + "Wandb entity and project. Required for authentication." + ) + entity: Optional[str] = Field( + None, + description="Name of an existing Wandb entity (team or user account) " + "to log experiments to." + ) + project_name: Optional[str] = Field( + None, + description="Name of an existing Wandb project to log experiments to. " + "If not specified, a default project will be used." + ) class WandbExperimentTrackerFlavor(BaseExperimentTrackerFlavor): diff --git a/src/zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py b/src/zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py index f7c5f025234..1696f7779c6 100644 --- a/src/zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py +++ b/src/zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Optional, Type +from pydantic import Field + from zenml.config.base_settings import BaseSettings from zenml.data_validators.base_data_validator import ( BaseDataValidatorConfig, @@ -28,17 +30,18 @@ class WhylogsDataValidatorSettings(BaseSettings): - """Settings for the Whylogs data validator. - - Attributes: - enable_whylabs: If set to `True` for a step, all the whylogs data - profile views returned by the step will automatically be uploaded - to the Whylabs platform if Whylabs credentials are configured. - dataset_id: Dataset ID to use when uploading profiles to Whylabs. - """ - - enable_whylabs: bool = False - dataset_id: Optional[str] = None + """Settings for the Whylogs data validator.""" + + enable_whylabs: bool = Field( + False, + description="If set to `True` for a step, all the whylogs data " + "profile views returned by the step will automatically be uploaded " + "to the Whylabs platform if Whylabs credentials are configured.", + ) + dataset_id: Optional[str] = Field( + None, + description="Dataset ID to use when uploading profiles to Whylabs.", + ) class WhylogsDataValidatorConfig( diff --git a/src/zenml/stack/authentication_mixin.py b/src/zenml/stack/authentication_mixin.py index 0d3a35e3fa1..6e4024d78b3 100644 --- a/src/zenml/stack/authentication_mixin.py +++ b/src/zenml/stack/authentication_mixin.py @@ -15,7 +15,7 @@ from typing import Optional, Type, TypeVar, cast -from pydantic import BaseModel +from pydantic import BaseModel, Field from zenml.client import Client from zenml.models import SecretResponse @@ -30,12 +30,13 @@ class AuthenticationConfigMixin(StackComponentConfig): Any stack component that implements `AuthenticationMixin` should have a config that inherits from this class. - Attributes: - authentication_secret: Name of the secret that stores the - authentication credentials. + Field descriptions are defined inline using Field() descriptors. 
""" - authentication_secret: Optional[str] = None + authentication_secret: Optional[str] = Field( + default=None, + description="Name of the ZenML secret containing authentication credentials.", + ) class AuthenticationMixin(StackComponent): diff --git a/src/zenml/stack/flavor.py b/src/zenml/stack/flavor.py index 0d861e70e21..f80b7f52936 100644 --- a/src/zenml/stack/flavor.py +++ b/src/zenml/stack/flavor.py @@ -248,9 +248,13 @@ def generate_default_sdk_docs_url(self) -> str: "zenml.integrations.", maxsplit=1 )[1].split(".")[0] + # Get the config class name to point to the specific class + config_class_name = self.config_class.__name__ + return ( f"{base}/integration_code_docs" - f"/integrations-{integration}/#{self.__module__}" + f"/integrations-{integration}" + f"#zenml.integrations.{integration}.flavors.{config_class_name}" ) else: diff --git a/tests/unit/test_flavor.py b/tests/unit/test_flavor.py index 11867eded0d..868b260f052 100644 --- a/tests/unit/test_flavor.py +++ b/tests/unit/test_flavor.py @@ -70,3 +70,18 @@ def test_docs_url(): assert AriaOrchestratorFlavor().docs_url == ( "https://docs.zenml.io/stack-components/orchestrators/aria" ) + + +def test_integration_sdk_docs_url(): + """Tests that integration SDK Docs URLs point to specific config classes.""" + from zenml.integrations.kubernetes.flavors.kubernetes_orchestrator_flavor import ( + KubernetesOrchestratorFlavor, + ) + + flavor = KubernetesOrchestratorFlavor() + expected_url = ( + f"https://sdkdocs.zenml.io/{zenml_version}/integration_code_docs" + f"/integrations-kubernetes" + f"#zenml.integrations.kubernetes.flavors.KubernetesOrchestratorConfig" + ) + assert flavor.sdk_docs_url == expected_url