diff --git a/tests/interop/test_validate_gpu_nodes.py b/tests/interop/test_validate_gpu_nodes.py index a03f9fb8..f38c8efc 100644 --- a/tests/interop/test_validate_gpu_nodes.py +++ b/tests/interop/test_validate_gpu_nodes.py @@ -1,11 +1,11 @@ import logging import os import re -import subprocess import pytest from ocp_resources.machine_set import MachineSet from ocp_resources.node import Node +from ocp_resources.pod import Pod from . import __loggername__ @@ -113,6 +113,7 @@ def test_validate_gpu_node_role_labels_pods(openshift_dyn_client): nodes = Node.get(dyn_client=openshift_dyn_client) gpu_nodes = [] + expected_count = 1 for node in nodes: logger.info(node.instance.metadata.name) labels = node.instance.metadata.labels @@ -125,9 +126,7 @@ def test_validate_gpu_node_role_labels_pods(openshift_dyn_client): if odh_label in label_str and worker_label in label_str: gpu_nodes.append(node) - # logger.info(node_count) - - if len(gpu_nodes) == 3: + if len(gpu_nodes) == int(expected_count): logger.info("PASS: Found 'worker' and 'odh-notebook' GPU node-role labels") else: err_msg = "Could not find 'worker' and 'odh-notebook' GPU node-role label" @@ -139,35 +138,23 @@ def test_validate_gpu_node_role_labels_pods(openshift_dyn_client): """ logger.info("Checking pod count on GPU nodes") - for gpu_node in gpu_nodes: - name = gpu_node.instance.metadata.name - field_select = "--field-selector=spec.host=" + name - pod_count = 0 - expected_count = 20 - failed_nodes = [] - cmd_out = subprocess.run( - [oc, "get", "pod", "-A", field_select, "--no-headers"], capture_output=True - ) - - if cmd_out.stdout: - out_decoded = cmd_out.stdout.decode("utf-8") - logger.info(node.instance.metadata.name + "\n" + out_decoded) - out_split = out_decoded.splitlines() - - for line in out_split: - if "Completed" in line: - continue - else: - pod_count += 1 - - if pod_count < expected_count: - failed_nodes.append(node.instance.metadata.name) - else: - assert False, cmd_out.stderr - - if failed_nodes: - err_msg = f"Did not find the expected pod count on: {failed_nodes}" + # We are assuming one GPU node + gpu_node = gpu_nodes[0].instance.metadata.name + nvidia_pods = [] + expected_count = 8 + project = "nvidia-gpu-operator" + pods = Pod.get(dyn_client=openshift_dyn_client, namespace=project) + + for pod in pods: + if "nvidia" in pod.instance.metadata.name: + logger.info(f"nvidia pod: {pod.instance.metadata.name}") + if gpu_node in pod.instance.spec.nodeName: + logger.info(f"nvidia pod node name: {pod.instance.spec.nodeName}") + nvidia_pods.append(pod.instance.metadata.name) + + if len(nvidia_pods) == int(expected_count): + logger.info("PASS: Found the expected nvidia pod count for GPU nodes") + else: + err_msg = "Did not find the expected nvidia pod count for GPU nodes" logger.error(f"FAIL: {err_msg}") assert False, err_msg - else: - logger.info("PASS: Found the expected pod count for GPU nodes")