Skip to content

Commit 91197e2

Browse files
authored
refactor: Fix type annotations and null checks in LMEval classes (#52)
1 parent 8378a2c commit 91197e2

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

src/llama_stack_provider_lmeval/lmeval.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def __init__(self, namespace: str = "default", service_account: str | None = Non
325325
"""
326326
self._namespace = namespace
327327
self._service_account = service_account
328-
self._config = None
328+
self._config: LMEvalEvalProviderConfig | None = None
329329

330330
@staticmethod
331331
def _build_openai_url(base_url: str) -> str:
@@ -347,7 +347,7 @@ def _build_openai_url(base_url: str) -> str:
347347
return f"{cleaned_url}/v1/completions"
348348

349349
def _create_model_args(
350-
self, base_url: str, benchmark_config: BenchmarkConfig
350+
self, base_url: str | None, benchmark_config: BenchmarkConfig
351351
) -> list[ModelArg]:
352352
"""Create model arguments for the LMEvalJob CR."""
353353
model_args = [
@@ -749,7 +749,8 @@ def create_cr(
749749
model_args = self._create_model_args(base_url, task_config)
750750

751751
if (
752-
hasattr(stored_benchmark, "metadata")
752+
stored_benchmark is not None
753+
and hasattr(stored_benchmark, "metadata")
753754
and stored_benchmark.metadata
754755
and "tokenizer" in stored_benchmark.metadata
755756
):
@@ -762,7 +763,8 @@ def create_cr(
762763

763764
# Add tokenized_requests parameter if present in metadata
764765
if (
765-
hasattr(stored_benchmark, "metadata")
766+
stored_benchmark is not None
767+
and hasattr(stored_benchmark, "metadata")
766768
and stored_benchmark.metadata
767769
and "tokenized_requests" in stored_benchmark.metadata
768770
):
@@ -832,7 +834,7 @@ def create_cr(
832834
if git_source_data:
833835
logger.info("Adding customTasks to CR with git data: %s", git_source_data)
834836

835-
custom_tasks_section = {"source": {"git": {}}}
837+
custom_tasks_section: dict[str, Any] = {"source": {"git": {}}}
836838

837839
for key, value in git_source_data.items():
838840
if value is not None:
@@ -926,12 +928,12 @@ def __init__(self, config: LMEvalEvalProviderConfig):
926928

927929
logger.debug("LMEval provider initialized with namespace: %s", self._namespace)
928930
logger.debug("LMEval provider config values: %s", vars(self._config))
929-
self.benchmarks = {}
931+
self.benchmarks: dict[str, Benchmark] = {}
930932
self._jobs: list[Job] = []
931-
self._job_metadata = {}
933+
self._job_metadata: dict[str, dict[str, Any]] = {}
932934

933-
self._k8s_client = None
934-
self._k8s_custom_api = None
935+
self._k8s_client: k8s_client.ApiClient | None = None
936+
self._k8s_custom_api: k8s_client.CustomObjectsApi | None = None
935937
if self.use_k8s:
936938
self._init_k8s_client()
937939
logger.debug(
@@ -1045,8 +1047,10 @@ def _deploy_lmeval_cr(self, cr: dict, job_id: str) -> None:
10451047
if "spec" in cr:
10461048
pvc_name = None
10471049

1048-
if hasattr(self._cr_builder, "_config") and hasattr(
1049-
self._cr_builder._config, "metadata"
1050+
if (
1051+
self._cr_builder._config is not None
1052+
and hasattr(self._cr_builder._config, "metadata")
1053+
and self._cr_builder._config.metadata
10501054
):
10511055
config_metadata = self._cr_builder._config.metadata
10521056
if (
@@ -1096,6 +1100,9 @@ def _deploy_lmeval_cr(self, cr: dict, job_id: str) -> None:
10961100
)
10971101
logger.info("Full Custom Resource being submitted: \n%s", cr_yaml)
10981102

1103+
if self._k8s_custom_api is None:
1104+
raise LMEvalConfigError("Kubernetes custom API not initialized")
1105+
10991106
try:
11001107
response = self._k8s_custom_api.create_namespaced_custom_object(
11011108
group=group,
@@ -1324,6 +1331,10 @@ async def job_status(self, benchmark_id: str, job_id: str) -> dict[str, str] | N
13241331
logger.warning("Job %s not found", job_id)
13251332
return None
13261333

1334+
if self._k8s_custom_api is None:
1335+
logger.warning("Kubernetes custom API not initialized")
1336+
return {"job_id": job_id, "status": JobStatus.scheduled}
1337+
13271338
try:
13281339
job_metadata = self._job_metadata.get(job_id, {})
13291340
k8s_name = job_metadata.get("k8s_name")
@@ -1392,6 +1403,10 @@ async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
13921403
logger.warning("Job %s not found", job_id)
13931404
return
13941405

1406+
if self._k8s_custom_api is None:
1407+
logger.warning("Kubernetes custom API not initialized")
1408+
return
1409+
13951410
try:
13961411
job_metadata = self._job_metadata.get(job_id, {})
13971412
k8s_name = job_metadata.get("k8s_name")
@@ -1455,6 +1470,10 @@ def _get_k8s_cr(self, k8s_name: str) -> dict[str, Any] | None:
14551470
Returns:
14561471
Custom resource as dictionary or None if not found
14571472
"""
1473+
if self._k8s_custom_api is None:
1474+
logger.warning("Kubernetes custom API not initialized")
1475+
return None
1476+
14581477
try:
14591478
group = "trustyai.opendatahub.io"
14601479
version = "v1alpha1"

0 commit comments

Comments
 (0)