Skip to content

Commit 8cbb0cb

Browse files
committed
Expect one GPU node, update nvidia pod count verification
1 parent: 96e9184 · commit: 8cbb0cb

File tree

1 file changed

+21
-33
lines changed

1 file changed

+21
-33
lines changed

tests/interop/test_validate_gpu_nodes.py

Lines changed: 21 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import subprocess
55

66
import pytest
7+
from ocp_resources.pod import Pod
78
from ocp_resources.machine_set import MachineSet
89
from ocp_resources.node import Node
910

@@ -113,6 +114,7 @@ def test_validate_gpu_node_role_labels_pods(openshift_dyn_client):
113114

114115
nodes = Node.get(dyn_client=openshift_dyn_client)
115116
gpu_nodes = []
117+
expected_count = 1
116118
for node in nodes:
117119
logger.info(node.instance.metadata.name)
118120
labels = node.instance.metadata.labels
@@ -125,9 +127,7 @@ def test_validate_gpu_node_role_labels_pods(openshift_dyn_client):
125127
if odh_label in label_str and worker_label in label_str:
126128
gpu_nodes.append(node)
127129

128-
# logger.info(node_count)
129-
130-
if len(gpu_nodes) == 3:
130+
if len(gpu_nodes) == int(expected_count):
131131
logger.info("PASS: Found 'worker' and 'odh-notebook' GPU node-role labels")
132132
else:
133133
err_msg = "Could not find 'worker' and 'odh-notebook' GPU node-role label"
@@ -139,35 +139,23 @@ def test_validate_gpu_node_role_labels_pods(openshift_dyn_client):
139139
"""
140140
logger.info("Checking pod count on GPU nodes")
141141

142-
for gpu_node in gpu_nodes:
143-
name = gpu_node.instance.metadata.name
144-
field_select = "--field-selector=spec.host=" + name
145-
pod_count = 0
146-
expected_count = 20
147-
failed_nodes = []
148-
cmd_out = subprocess.run(
149-
[oc, "get", "pod", "-A", field_select, "--no-headers"], capture_output=True
150-
)
151-
152-
if cmd_out.stdout:
153-
out_decoded = cmd_out.stdout.decode("utf-8")
154-
logger.info(node.instance.metadata.name + "\n" + out_decoded)
155-
out_split = out_decoded.splitlines()
156-
157-
for line in out_split:
158-
if "Completed" in line:
159-
continue
160-
else:
161-
pod_count += 1
162-
163-
if pod_count < expected_count:
164-
failed_nodes.append(node.instance.metadata.name)
165-
else:
166-
assert False, cmd_out.stderr
167-
168-
if failed_nodes:
169-
err_msg = f"Did not find the expected pod count on: {failed_nodes}"
142+
# We are assuming one GPU node
143+
gpu_node = gpu_nodes[0].instance.metadata.name
144+
nvidia_pods = []
145+
expected_count = 8
146+
project = "nvidia-gpu-operator"
147+
pods = Pod.get(dyn_client=openshift_dyn_client, namespace=project)
148+
149+
for pod in pods:
150+
if "nvidia" in pod.instance.metadata.name:
151+
logger.info(f"nvidia pod: {pod.instance.metadata.name}")
152+
if gpu_node in pod.instance.spec.nodeName:
153+
logger.info(f"nvidia pod node name: {pod.instance.spec.nodeName}")
154+
nvidia_pods.append(pod.instance.metadata.name)
155+
156+
if len(nvidia_pods) == int(expected_count):
157+
logger.info("PASS: Found the expected nvidia pod count for GPU nodes")
158+
else:
159+
err_msg = f"Did not find the expected nvidia pod count for GPU nodes"
170160
logger.error(f"FAIL: {err_msg}")
171161
assert False, err_msg
172-
else:
173-
logger.info("PASS: Found the expected pod count for GPU nodes")

0 commit comments

Comments
 (0)