44import re
55import sys
66from pathlib import Path
7- from typing import List , Optional , Union
7+ from typing import List , Optional , Tuple
88from uuid import uuid4
99
1010import kubernetes
1111import kubernetes .client .models as kubernetes_models
12+ from typing_extensions import Literal
1213
1314from .cluster_executor import ClusterExecutor
1415
@@ -19,9 +20,9 @@ def volume_name_from_path(path: Path) -> str:
1920
2021def deduplicate_mounts (mounts : List [Path ]) -> List [Path ]:
2122 output = []
22- mounts = set (mounts )
23- for mount in mounts :
24- if not any (m in mount .parents for m in mounts ):
23+ unique_mounts = set (mounts )
24+ for mount in unique_mounts :
25+ if not any (m in mount .parents for m in unique_mounts ):
2526 output .append (mount )
2627 return output
2728
@@ -73,7 +74,7 @@ def get_current_job_id() -> Optional[str]:
7374 return os .environ .get ("JOB_ID" , None )
7475
7576 @classmethod
76- def get_job_id_string (cls ) -> str :
77+ def get_job_id_string (cls ) -> Optional [ str ] :
7778 job_id = cls .get_current_job_id ()
7879 job_index = cls .get_job_array_index ()
7980 if job_index is None :
@@ -104,21 +105,21 @@ def inner_submit(
104105 self ,
105106 cmdline : str ,
106107 job_name : Optional [str ] = None ,
108+ additional_setup_lines : Optional [List [str ]] = None ,
107109 job_count : Optional [int ] = None ,
108- ** _ ,
109- ):
110+ ) -> Tuple [List ["concurrent.futures.Future[str]" ], List [Tuple [int , int ]]]:
110111 """Starts a Kubernetes pod that runs the specified shell command line."""
111112
112113 kubernetes_client = KubernetesClient ()
113114 self .ensure_kubernetes_namespace ()
114115 job_id = str (uuid4 ())
115116
116- job_id_future = concurrent .futures .Future ()
117+ job_id_future : "concurrent.futures.Future[str]" = concurrent .futures .Future ()
117118 job_id_future .set_result (job_id )
118119 job_id_futures = [job_id_future ]
119120
120121 is_array_job = job_count is not None
121- number_of_subjobs = job_count if is_array_job else 1
122+ number_of_subjobs = job_count if job_count is not None else 1
122123 ranges = [(0 , number_of_subjobs )]
123124
124125 requested_resources = {
@@ -232,12 +233,12 @@ def inner_submit(
232233
233234 def check_for_crashed_job (
234235 self , job_id_with_index : str
235- ) -> Union ["failed" , "ignore" , "completed" ]:
236+ ) -> Literal ["failed" , "ignore" , "completed" ]:
236237 kubernetes_client = KubernetesClient ()
237238 [job_id , job_index ] = (
238239 job_id_with_index .split ("_" )
239240 if "_" in job_id_with_index
240- else [job_id_with_index , 0 ]
241+ else [job_id_with_index , "0" ]
241242 )
242243 resp = kubernetes_client .core .list_namespaced_pod (
243244 namespace = self .job_resources ["namespace" ],
0 commit comments