diff --git a/docs/book/reference/legacy-docs.md b/docs/book/reference/legacy-docs.md index 7fd38944d35..4b2991b9583 100644 --- a/docs/book/reference/legacy-docs.md +++ b/docs/book/reference/legacy-docs.md @@ -5,4 +5,4 @@ icon: user-bounty-hunter # Legacy docs -
0.83.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.83.1/
0.83.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.83.0/
0.82.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.82.0/
0.81.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.81.0/
0.80.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.2/
0.80.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.1/
0.80.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.0/
0.75.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.75.0/
0.74.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.74.0/
0.73.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.73.0/
0.72.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.72.0
0.71.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.71.0
0.70.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.70.0
0.68.1https://zenml-io.gitbook.io/zenml-legacy-documentation/0.68.1
0.68.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.68.0/
0.67.0https://app.gitbook.com/o/-MCl1Hlw9oU4xibZ-ymf/s/IAT17KUXNgudlhTyCPn4/
0.66.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.66.0
0.65.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.65.0/
0.64.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.64.0/
0.63.0https://app.gitbook.com/o/-MCl1Hlw9oU4xibZ-ymf/s/N7uiChp9LXuqoNrLj3U3/
0.62.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.62.0/
0.61.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.61.0
0.60.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.60.0/
0.58.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.2/
0.58.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.1/
0.58.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.0
0.57.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.57.1
0.57.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.57.0
0.56.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.4
0.56.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.3
0.56.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.2
0.56.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.1
0.55.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.5
0.55.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.4
0.55.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.3
0.55.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.2
0.55.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.1
0.55.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.0
0.54.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.54.1
0.54.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.54.0
0.53.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.53.1
0.53.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.53.0
0.52.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.52.0
0.51.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.51.0
0.50.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.50.0
0.47.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.47.0-legacy
0.46.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.46.1-legacy
0.46.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.46.0-legacy
0.45.6https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.6-legacy
0.45.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.5-legacy
0.45.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.4-legacy
0.45.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.3-legacy
0.45.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.2-legacy
0.44.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.4-legacy
0.44.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.3-legacy
0.44.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.2-legacy
0.44.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.1-legacy
0.43.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.43.1-legacy
0.43.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.43.0-legacy
0.42.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.2-legacy
0.42.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.1-legacy
0.42.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.0-legacy
0.41.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.41.0-legacy
0.40.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.3-legacy
0.40.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.2-legacy
0.40.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.1-legacy
0.40.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.0-legacy
0.39.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.39.1-legacy
0.39.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.39.0-legacy
0.38.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.38.0-legacy
0.37.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.37.0-legacy
0.36.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.36.1-legacy
0.36.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.36.0-legacy
0.35.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.35.1-legacy
0.35.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.35.0-legacy
0.34.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.34.0-legacy
0.33.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.33.0-legacy
0.32.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.32.1-legacy
0.32.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.32.0-legacy
0.31.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.31.1-legacy
0.31.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.31.0-legacy
0.30.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.30.0-legacy
0.20.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.5-legacy
0.20.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.4-legacy
0.20.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.3-legacy
0.20.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.2-legacy
0.20.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.0-legacy
0.13.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.2
0.13.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.1
0.13.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.0
0.12.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.12.0
0.11.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.11.0
0.10.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.10.0
+
0.84.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.84.0/
0.83.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.83.1/
0.83.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.83.0/
0.82.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.82.0/
0.81.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.81.0/
0.80.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.2/
0.80.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.1/
0.80.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.0/
0.75.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.75.0/
0.74.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.74.0/
0.73.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.73.0/
0.72.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.72.0
0.71.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.71.0
0.70.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.70.0
0.68.1https://zenml-io.gitbook.io/zenml-legacy-documentation/0.68.1
0.68.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.68.0/
0.67.0https://app.gitbook.com/o/-MCl1Hlw9oU4xibZ-ymf/s/IAT17KUXNgudlhTyCPn4/
0.66.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.66.0
0.65.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.65.0/
0.64.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.64.0/
0.63.0https://app.gitbook.com/o/-MCl1Hlw9oU4xibZ-ymf/s/N7uiChp9LXuqoNrLj3U3/
0.62.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.62.0/
0.61.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.61.0
0.60.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.60.0/
0.58.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.2/
0.58.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.1/
0.58.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.0
0.57.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.57.1
0.57.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.57.0
0.56.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.4
0.56.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.3
0.56.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.2
0.56.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.1
0.55.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.5
0.55.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.4
0.55.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.3
0.55.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.2
0.55.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.1
0.55.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.0
0.54.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.54.1
0.54.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.54.0
0.53.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.53.1
0.53.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.53.0
0.52.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.52.0
0.51.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.51.0
0.50.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.50.0
0.47.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.47.0-legacy
0.46.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.46.1-legacy
0.46.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.46.0-legacy
0.45.6https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.6-legacy
0.45.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.5-legacy
0.45.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.4-legacy
0.45.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.3-legacy
0.45.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.2-legacy
0.44.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.4-legacy
0.44.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.3-legacy
0.44.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.2-legacy
0.44.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.1-legacy
0.43.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.43.1-legacy
0.43.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.43.0-legacy
0.42.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.2-legacy
0.42.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.1-legacy
0.42.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.0-legacy
0.41.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.41.0-legacy
0.40.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.3-legacy
0.40.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.2-legacy
0.40.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.1-legacy
0.40.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.0-legacy
0.39.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.39.1-legacy
0.39.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.39.0-legacy
0.38.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.38.0-legacy
0.37.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.37.0-legacy
0.36.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.36.1-legacy
0.36.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.36.0-legacy
0.35.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.35.1-legacy
0.35.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.35.0-legacy
0.34.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.34.0-legacy
0.33.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.33.0-legacy
0.32.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.32.1-legacy
0.32.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.32.0-legacy
0.31.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.31.1-legacy
0.31.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.31.0-legacy
0.30.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.30.0-legacy
0.20.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.5-legacy
0.20.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.4-legacy
0.20.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.3-legacy
0.20.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.2-legacy
0.20.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.0-legacy
0.13.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.2
0.13.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.1
0.13.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.0
0.12.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.12.0
0.11.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.11.0
0.10.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.10.0
diff --git a/src/zenml/integrations/kubeflow/__init__.py b/src/zenml/integrations/kubeflow/__init__.py index 600c7e587a6..3ebe15698a9 100644 --- a/src/zenml/integrations/kubeflow/__init__.py +++ b/src/zenml/integrations/kubeflow/__init__.py @@ -30,7 +30,10 @@ class KubeflowIntegration(Integration): """Definition of Kubeflow Integration for ZenML.""" NAME = KUBEFLOW - REQUIREMENTS = ["kfp>=2.6.0", "kfp-kubernetes>=1.1.0"] # Only 1.x version that supports pyyaml 6 + REQUIREMENTS = [ + "kfp>=2.6.0", + "kfp-kubernetes>=1.1.0", + ] REQUIREMENTS_IGNORED_ON_UNINSTALL = [ "kfp", # it is used by GCP as well ] @@ -45,7 +48,10 @@ def flavors(cls) -> List[Type[Flavor]]: from zenml.integrations.kubeflow.flavors import ( KubeflowOrchestratorFlavor, ) + from zenml.integrations.kubeflow.flavors.kubeflow_training_step_operator_flavor import ( + KubeflowTrainingStepOperatorFlavor, + ) - return [KubeflowOrchestratorFlavor] + return [KubeflowOrchestratorFlavor, KubeflowTrainingStepOperatorFlavor] diff --git a/src/zenml/integrations/kubeflow/flavors/kubeflow_training_step_operator_flavor.py b/src/zenml/integrations/kubeflow/flavors/kubeflow_training_step_operator_flavor.py new file mode 100644 index 00000000000..5b2d191775f --- /dev/null +++ b/src/zenml/integrations/kubeflow/flavors/kubeflow_training_step_operator_flavor.py @@ -0,0 +1,244 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Kubeflow Training step operator flavor.""" + +from typing import TYPE_CHECKING, Any, Dict, Optional, Type + +from pydantic import Field + +from zenml.config.base_settings import BaseSettings +from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE +from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings +from zenml.models import ServiceConnectorRequirements +from zenml.step_operators import BaseStepOperatorConfig, BaseStepOperatorFlavor + +if TYPE_CHECKING: + from zenml.integrations.kubeflow.step_operators import ( + KubeflowTrainingStepOperator, + ) + +KUBEFLOW_TRAINING_STEP_OPERATOR_FLAVOR = "kubeflow_training" + + +class KubeflowTrainingStepOperatorSettings(BaseSettings): + """Settings for the Kubeflow Training step operator. + + Attributes: + training_job_type: Type of training job to create. Supports: + "PyTorch", "TensorFlow", "JAX", "MPI". Defaults to "PyTorch". + num_workers: Number of worker replicas for the training job. + num_ps: Number of parameter server replicas (TFJob only). + restart_policy: Restart policy for the training job pods. + clean_pod_policy: When to clean up the job pods after completion. + ttl_seconds_after_finished: TTL for the training job after completion. + pod_template_spec: Additional pod template specification to apply + to training job pods. + worker_pod_settings: Pod settings to apply to worker pods. + master_pod_settings: Pod settings to apply to master/chief pods. + ps_pod_settings: Pod settings to apply to parameter server pods (TFJob only). + service_account_name: Name of the service account to use for the pods. + pod_startup_timeout: Maximum time to wait for pods to start (in seconds). + job_completion_timeout: Maximum time to wait for job completion (in seconds). + """ + + training_job_type: str = Field( + default="PyTorch", + description="Type of training job to create", + ) + num_workers: int = Field( + default=1, + ge=1, + description="Number of worker replicas", + ) + num_ps: Optional[int] = Field( + default=None, + ge=0, + description="Number of parameter server replicas (TFJob only)", + ) + restart_policy: str = Field( + default="OnFailure", + description="Restart policy for training job pods", + ) + clean_pod_policy: str = Field( + default="All", + description="When to clean up job pods after completion", + ) + ttl_seconds_after_finished: Optional[int] = Field( + default=None, + ge=0, + description="TTL for job cleanup after completion", + ) + pod_template_spec: Optional[Dict[str, Any]] = Field( + default=None, + description="Additional pod template specification", + ) + worker_pod_settings: Optional[KubernetesPodSettings] = Field( + default=None, + description="Pod settings for worker pods", + ) + master_pod_settings: Optional[KubernetesPodSettings] = Field( + default=None, + description="Pod settings for master/chief pods", + ) + ps_pod_settings: Optional[KubernetesPodSettings] = Field( + default=None, + description="Pod settings for parameter server pods", + ) + service_account_name: Optional[str] = Field( + default=None, + description="Service account name for the pods", + ) + pod_startup_timeout: int = Field( + default=600, + ge=1, + description="Maximum time to wait for pods to start (seconds)", + ) + job_completion_timeout: int = Field( + default=3600, + ge=1, + description="Maximum time to wait for job completion (seconds)", + ) + + +class KubeflowTrainingStepOperatorConfig( + BaseStepOperatorConfig, KubeflowTrainingStepOperatorSettings +): + """Configuration for the Kubeflow Training step operator. + + Attributes: + kubeflow_namespace: Name of the Kubernetes namespace where Kubeflow + Training Operator is deployed. Defaults to "kubeflow". + incluster: If `True`, the step operator will run inside the same + cluster. For this to work, the pod running the orchestrator needs + permissions to create training jobs. If set, `kubernetes_context` + is ignored. If the stack component is linked to a Kubernetes + service connector, this field is ignored. + kubernetes_context: Name of a Kubernetes context to use. If the stack + component is linked to a Kubernetes service connector, this field + is ignored. Otherwise, it is mandatory. + """ + + kubeflow_namespace: str = Field( + default="kubeflow", + description="Kubernetes namespace for Kubeflow Training Operator", + ) + incluster: bool = Field( + default=False, + description="Whether to run inside the same cluster", + ) + kubernetes_context: Optional[str] = Field( + default=None, + description="Kubernetes context to use", + ) + + @property + def is_remote(self) -> bool: + """Checks if this stack component is running remotely. + + This designation is used to determine if the stack component can be + used with a local ZenML database or if it requires a remote ZenML + server. + + Returns: + True if this config is for a remote component, False otherwise. + """ + return True + + @property + def is_local(self) -> bool: + """Checks if this stack component is running locally. + + Returns: + True if this config is for a local component, False otherwise. + """ + return False + + +class KubeflowTrainingStepOperatorFlavor(BaseStepOperatorFlavor): + """Kubeflow Training step operator flavor.""" + + @property + def name(self) -> str: + """Name of the flavor. + + Returns: + The name of the flavor. + """ + return KUBEFLOW_TRAINING_STEP_OPERATOR_FLAVOR + + @property + def service_connector_requirements( + self, + ) -> Optional[ServiceConnectorRequirements]: + """Service connector resource requirements for service connectors. + + Specifies resource requirements that are used to filter the available + service connector types that are compatible with this flavor. + + Returns: + Requirements for compatible service connectors, if a service + connector is required for this flavor. + """ + return ServiceConnectorRequirements( + resource_type=KUBERNETES_CLUSTER_RESOURCE_TYPE, + ) + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. + """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/kubeflow.png" + + @property + def config_class(self) -> Type[KubeflowTrainingStepOperatorConfig]: + """Returns `KubeflowTrainingStepOperatorConfig` config class. + + Returns: + The config class. + """ + return KubeflowTrainingStepOperatorConfig + + @property + def implementation_class(self) -> Type["KubeflowTrainingStepOperator"]: + """Implementation class for this flavor. + + Returns: + The implementation class. + """ + from zenml.integrations.kubeflow.step_operators import ( + KubeflowTrainingStepOperator, + ) + + return KubeflowTrainingStepOperator diff --git a/src/zenml/integrations/kubeflow/step_operators/__init__.py b/src/zenml/integrations/kubeflow/step_operators/__init__.py new file mode 100644 index 00000000000..cc14eecc5a0 --- /dev/null +++ b/src/zenml/integrations/kubeflow/step_operators/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Kubeflow Training step operators.""" + +from zenml.integrations.kubeflow.step_operators.kubeflow_training_step_operator import ( + KubeflowTrainingStepOperator, +) + +__all__ = [ + "KubeflowTrainingStepOperator", +] \ No newline at end of file diff --git a/src/zenml/integrations/kubeflow/step_operators/kubeflow_training_step_operator.py b/src/zenml/integrations/kubeflow/step_operators/kubeflow_training_step_operator.py new file mode 100644 index 00000000000..2776294594c --- /dev/null +++ b/src/zenml/integrations/kubeflow/step_operators/kubeflow_training_step_operator.py @@ -0,0 +1,451 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Kubeflow Training step operator implementation.""" + +import time +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, cast + +from kubernetes import client as k8s_client + +from zenml.config.base_settings import BaseSettings +from zenml.config.build_configuration import BuildConfiguration +from zenml.enums import StackComponentType +from zenml.integrations.kubeflow.flavors.kubeflow_training_step_operator_flavor import ( + KubeflowTrainingStepOperatorConfig, + KubeflowTrainingStepOperatorSettings, +) +from zenml.integrations.kubernetes.orchestrators import kube_utils +from zenml.logger import get_logger +from zenml.stack import Stack, StackValidator +from zenml.step_operators import BaseStepOperator + +if TYPE_CHECKING: + from zenml.config.step_run_info import StepRunInfo + from zenml.models import PipelineDeploymentBase + +logger = get_logger(__name__) + +KUBEFLOW_TRAINING_STEP_OPERATOR_DOCKER_IMAGE_KEY = ( + "kubeflow_training_step_operator" +) + +# Supported training job types for Kubeflow Training API v1alpha1 +SUPPORTED_JOB_TYPES = { + "PyTorch": { + "api_version": "trainer.kubeflow.org/v1alpha1", + "kind": "TrainJob", + "framework": "pytorch", + }, + "TensorFlow": { + "api_version": "trainer.kubeflow.org/v1alpha1", + "kind": "TrainJob", + "framework": "tensorflow", + }, + "JAX": { + "api_version": "trainer.kubeflow.org/v1alpha1", + "kind": "TrainJob", + "framework": "jax", + }, + "MPI": { + "api_version": "trainer.kubeflow.org/v1alpha1", + "kind": "TrainJob", + "framework": "mpi", + }, +} + + +class KubeflowTrainingStepOperator(BaseStepOperator): + """Step operator to run training jobs on Kubeflow Training Operator.""" + + _k8s_client: Optional[k8s_client.ApiClient] = None + + @property + def config(self) -> KubeflowTrainingStepOperatorConfig: + """Returns the `KubeflowTrainingStepOperatorConfig` config. + + Returns: + The configuration. + """ + return cast(KubeflowTrainingStepOperatorConfig, self._config) + + @property + def settings_class(self) -> Optional[Type["BaseSettings"]]: + """Settings class for the Kubeflow Training step operator. + + Returns: + The settings class. + """ + return KubeflowTrainingStepOperatorSettings + + @property + def validator(self) -> Optional[StackValidator]: + """Validates the stack. + + Returns: + A validator that checks that the stack contains a remote container + registry and a remote artifact store. + """ + + def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + if stack.artifact_store.config.is_local: + return False, ( + "The Kubeflow Training step operator runs code remotely and " + "needs to write files into the artifact store, but the " + f"artifact store `{stack.artifact_store.name}` of the " + "active stack is local. Please ensure that your stack " + "contains a remote artifact store when using the Kubeflow " + "Training step operator." + ) + + container_registry = stack.container_registry + assert container_registry is not None + + if container_registry.config.is_local: + return False, ( + "The Kubeflow Training step operator runs code remotely and " + "needs to push/pull Docker images, but the " + f"container registry `{container_registry.name}` of the " + "active stack is local. Please ensure that your stack " + "contains a remote container registry when using the " + "Kubeflow Training step operator." + ) + + return True, "" + + return StackValidator( + required_components={ + StackComponentType.CONTAINER_REGISTRY, + StackComponentType.IMAGE_BUILDER, + }, + custom_validation_function=_validate_remote_components, + ) + + def get_docker_builds( + self, deployment: "PipelineDeploymentBase" + ) -> List["BuildConfiguration"]: + """Gets the Docker builds required for the component. + + Args: + deployment: The pipeline deployment for which to get the builds. + + Returns: + The required Docker builds. + """ + builds = [] + for step_name, step in deployment.step_configurations.items(): + if step.config.uses_step_operator(self.name): + build = BuildConfiguration( + key=KUBEFLOW_TRAINING_STEP_OPERATOR_DOCKER_IMAGE_KEY, + settings=step.config.docker_settings, + step_name=step_name, + ) + builds.append(build) + + return builds + + def get_kube_client(self) -> k8s_client.ApiClient: + """Get the Kubernetes API client. + + Returns: + The Kubernetes API client. + + Raises: + RuntimeError: If the service connector returns an unexpected client. + """ + if self.config.incluster: + kube_utils.load_kube_config(incluster=True) + self._k8s_client = k8s_client.ApiClient() + return self._k8s_client + + # Refresh the client also if the connector has expired + if self._k8s_client and not self.connector_has_expired(): + return self._k8s_client + + connector = self.get_connector() + if connector: + client = connector.connect() + if not isinstance(client, k8s_client.ApiClient): + raise RuntimeError( + f"Expected a k8s_client.ApiClient while trying to use the " + f"linked connector, but got {type(client)}." + ) + self._k8s_client = client + else: + kube_utils.load_kube_config( + context=self.config.kubernetes_context, + ) + self._k8s_client = k8s_client.ApiClient() + + return self._k8s_client + + @property + def _k8s_custom_objects_api(self) -> k8s_client.CustomObjectsApi: + """Getter for the Kubernetes Custom Objects API client. + + Returns: + The Kubernetes Custom Objects API client. + """ + return k8s_client.CustomObjectsApi(self.get_kube_client()) + + def _build_training_job_manifest( + self, + job_name: str, + image: str, + command: List[str], + environment: Dict[str, str], + settings: KubeflowTrainingStepOperatorSettings, + ) -> Dict[str, any]: + """Build the Kubeflow Training Job manifest for API v1alpha1. + + Args: + job_name: Name of the training job. + image: Docker image to use for training. + command: Command to execute for training. + environment: Environment variables to set. + settings: Step operator settings. + + Returns: + The training job manifest as a dictionary. + + Raises: + ValueError: If the training job type is not supported. + """ + # Map old job type names to new ones for backwards compatibility + job_type_mapping = { + "PyTorchJob": "PyTorch", + "TFJob": "TensorFlow", + "MPIJob": "MPI", + "XGBoostJob": "XGBoost", # if we add XGBoost support later + } + + training_job_type = job_type_mapping.get( + settings.training_job_type, settings.training_job_type + ) + + if training_job_type not in SUPPORTED_JOB_TYPES: + raise ValueError( + f"Unsupported training job type: {settings.training_job_type}. " + f"Supported types: {list(SUPPORTED_JOB_TYPES.keys())}" + ) + + # Build the TrainJob spec for Kubeflow Training API v1alpha1 + job_spec = SUPPORTED_JOB_TYPES[training_job_type] + + # Map framework to runtime name + runtime_mapping = { + "pytorch": "torch-distributed", + "tensorflow": "torch-distributed", # fallback to torch + "jax": "torch-distributed", # fallback to torch + "mpi": "mpi-distributed", + } + + runtime_name = runtime_mapping.get( + job_spec["framework"], "torch-distributed" + ) + + train_spec = { + "runtimeRef": { + "name": runtime_name, + }, + "trainer": { + "image": image, + "command": command, + "env": [ + {"name": k, "value": v} for k, v in environment.items() + ], + }, + } + + # Add resource requirements if specified + if ( + settings.worker_pod_settings + and settings.worker_pod_settings.resources + ): + train_spec["trainer"]["resources"] = ( + settings.worker_pod_settings.resources + ) + + # Build the complete manifest for TrainJob + manifest = { + "apiVersion": job_spec["api_version"], + "kind": job_spec["kind"], + "metadata": { + "name": job_name, + "namespace": self.config.kubeflow_namespace, + }, + "spec": train_spec, + } + + # Add optional fields + if settings.ttl_seconds_after_finished is not None: + manifest["spec"]["ttlSecondsAfterFinished"] = ( + settings.ttl_seconds_after_finished + ) + + return manifest + + def _wait_for_training_job_completion( + self, + job_name: str, + timeout: int, + ) -> bool: + """Wait for a training job to complete. + + Args: + job_name: Name of the training job. + timeout: Maximum time to wait in seconds. + + Returns: + True if job completed successfully, False if failed. + + Raises: + TimeoutError: If job doesn't complete within timeout. + """ + api = self._k8s_custom_objects_api + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = api.get_namespaced_custom_object( + group="trainer.kubeflow.org", + version="v1alpha1", + namespace=self.config.kubeflow_namespace, + plural="trainjobs", + name=job_name, + ) + + status = response.get("status", {}) + conditions = status.get("conditions", []) + + for condition in conditions: + if ( + condition.get("type") == "Succeeded" + and condition.get("status") == "True" + ): + logger.info( + f"Training job {job_name} completed successfully" + ) + return True + elif ( + condition.get("type") == "Failed" + and condition.get("status") == "True" + ): + logger.error( + f"Training job {job_name} failed: {condition.get('reason', 'Unknown')}" + ) + return False + + logger.info(f"Training job {job_name} is still running...") + time.sleep(30) # Check every 30 seconds + + except Exception as e: + logger.warning(f"Error checking job status: {e}") + time.sleep(30) + + raise TimeoutError( + f"Training job {job_name} did not complete within {timeout} seconds" + ) + + def launch( + self, + info: "StepRunInfo", + entrypoint_command: List[str], + environment: Dict[str, str], + ) -> None: + """Launches a training job on Kubeflow Training Operator. + + Args: + info: Information about the step run. + entrypoint_command: Command that executes the step. + environment: Environment variables to set in the step operator + environment. + """ + settings = cast( + KubeflowTrainingStepOperatorSettings, self.get_settings(info) + ) + image_name = info.get_image( + key=KUBEFLOW_TRAINING_STEP_OPERATOR_DOCKER_IMAGE_KEY + ) + + # Create a shorter job name to avoid 63-character limit + # Format: {pipeline_name}-{step_name}-{short_run_id} + run_id_short = info.run_name.split('-')[-1][:8] # Last 8 chars of run ID + job_name = f"{info.pipeline.name}-{info.pipeline_step_name}-{run_id_short}" + job_name = kube_utils.sanitize_pod_name( + job_name, namespace=self.config.kubeflow_namespace + ) + + logger.info( + f"Launching Kubeflow {settings.training_job_type} training job: {job_name}" + ) + + # Build training job manifest + manifest = self._build_training_job_manifest( + job_name=job_name, + image=image_name, + command=entrypoint_command, + environment=environment, + settings=settings, + ) + + # Submit the training job + api = self._k8s_custom_objects_api + + try: + api.create_namespaced_custom_object( + group="trainer.kubeflow.org", + version="v1alpha1", + namespace=self.config.kubeflow_namespace, + plural="trainjobs", + body=manifest, + ) + logger.info(f"Successfully created training job {job_name}") + + except Exception as e: + logger.error(f"Failed to create training job {job_name}: {e}") + raise + + # Wait for job completion + try: + success = self._wait_for_training_job_completion( + job_name=job_name, + timeout=settings.job_completion_timeout, + ) + + if not success: + raise RuntimeError(f"Training job {job_name} failed") + + except TimeoutError: + logger.warning( + f"Training job {job_name} did not complete within timeout" + ) + raise + except Exception as e: + logger.error(f"Error waiting for training job completion: {e}") + raise + finally: + # Cleanup job if configured to do so + if settings.ttl_seconds_after_finished is None: + try: + api.delete_namespaced_custom_object( + group="trainer.kubeflow.org", + version="v1alpha1", + namespace=self.config.kubeflow_namespace, + plural="trainjobs", + name=job_name, + ) + logger.info(f"Cleaned up training job {job_name}") + except Exception as e: + logger.warning( + f"Failed to cleanup training job {job_name}: {e}" + )