Skip to content

Commit 44edd6e

Browse files
committed
Switch to pytest
1 parent e6e6404 commit 44edd6e

File tree

2 files changed

+26
-40
lines changed

2 files changed

+26
-40
lines changed

ci/L0_additional_outputs_vllm/additional_outputs_test.py

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

2727
import json
28-
import unittest
2928

3029
import numpy as np
30+
import pytest
3131
import tritonclient.grpc as grpcclient
3232

3333

34-
class InferTest(unittest.TestCase):
34+
class TestAdditionalOutputs:
3535
_grpc_url = "localhost:8001"
3636
_model_name = "vllm_opt"
3737
_sampling_parameters = {"temperature": "0", "top_p": "1"}
@@ -93,51 +93,51 @@ def _llm_infer(self, inputs):
9393
self._model_name, inputs=inputs, parameters=self._sampling_parameters
9494
)
9595
client.stop_stream()
96-
self.assertGreater(len(self._responses), 0)
96+
assert len(self._responses) > 0
9797

9898
def _assert_text_output_valid(self):
9999
text_output = ""
100100
for response in self._responses:
101101
result, error = response["result"], response["error"]
102-
self.assertIsNone(error)
102+
assert error is None
103103
text_output += result.as_numpy(name="text_output")[0].decode("utf-8")
104-
self.assertGreater(len(text_output), 0, "output is empty")
105-
self.assertGreater(text_output.count(" "), 4, "output is not a sentence")
104+
assert len(text_output) > 0, "output is empty"
105+
assert text_output.count(" ") > 4, "output is not a sentence"
106106

107107
def _assert_finish_reason(self, output_finish_reason):
108108
for i in range(len(self._responses)):
109109
result, error = self._responses[i]["result"], self._responses[i]["error"]
110-
self.assertIsNone(error)
110+
assert error is None
111111
finish_reason_np = result.as_numpy(name="finish_reason")
112112
if output_finish_reason is None or output_finish_reason == False:
113-
self.assertIsNone(finish_reason_np)
113+
assert finish_reason_np is None
114114
continue
115115
finish_reason = finish_reason_np[0].decode("utf-8")
116116
if i < len(self._responses) - 1:
117-
self.assertEqual(finish_reason, "None")
117+
assert finish_reason == "None"
118118
else:
119-
self.assertEqual(finish_reason, "length")
119+
assert finish_reason == "length"
120120

121121
def _assert_cumulative_logprob(self, output_cumulative_logprob):
122122
prev_cumulative_logprob = 0.0
123123
for response in self._responses:
124124
result, error = response["result"], response["error"]
125-
self.assertIsNone(error)
125+
assert error is None
126126
cumulative_logprob_np = result.as_numpy(name="cumulative_logprob")
127127
if output_cumulative_logprob is None or output_cumulative_logprob == False:
128-
self.assertIsNone(cumulative_logprob_np)
128+
assert cumulative_logprob_np is None
129129
continue
130130
cumulative_logprob = cumulative_logprob_np[0].astype(float)
131-
self.assertNotEqual(cumulative_logprob, prev_cumulative_logprob)
131+
assert cumulative_logprob != prev_cumulative_logprob
132132
prev_cumulative_logprob = cumulative_logprob
133133

134134
def _assert_num_token_ids(self, output_num_token_ids):
135135
for response in self._responses:
136136
result, error = response["result"], response["error"]
137-
self.assertIsNone(error)
137+
assert error is None
138138
num_token_ids_np = result.as_numpy(name="num_token_ids")
139139
if output_num_token_ids is None or output_num_token_ids == False:
140-
self.assertIsNone(num_token_ids_np)
140+
assert num_token_ids_np is None
141141
continue
142142
num_token_ids = num_token_ids_np[0].astype(int)
143143
# TODO: vLLM may return token ids identical to the previous one when
@@ -156,10 +156,14 @@ def _assert_num_token_ids(self, output_num_token_ids):
156156
# curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
157157
#
158158
# If this is no longer the case in a future release, change the assert
159-
# to assertGreater().
160-
self.assertGreaterEqual(num_token_ids, 0)
161-
162-
def _assert_additional_outputs_valid(
159+
# to assert num_token_ids > 0.
160+
assert num_token_ids >= 0
161+
162+
@pytest.mark.parametrize("stream", [True, False])
163+
@pytest.mark.parametrize("output_finish_reason", [None, True, False])
164+
@pytest.mark.parametrize("output_cumulative_logprob", [None, True, False])
165+
@pytest.mark.parametrize("output_num_token_ids", [None, True, False])
166+
def test_additional_outputs(
163167
self,
164168
stream,
165169
output_finish_reason,
@@ -179,20 +183,3 @@ def _assert_additional_outputs_valid(
179183
self._assert_finish_reason(output_finish_reason)
180184
self._assert_cumulative_logprob(output_cumulative_logprob)
181185
self._assert_num_token_ids(output_num_token_ids)
182-
183-
def test_additional_outputs(self):
184-
for stream in [True, False]:
185-
choices = [None, False, True]
186-
for output_finish_reason in choices:
187-
for output_cumulative_logprob in choices:
188-
for output_num_token_ids in choices:
189-
self._assert_additional_outputs_valid(
190-
stream,
191-
output_finish_reason,
192-
output_cumulative_logprob,
193-
output_num_token_ids,
194-
)
195-
196-
197-
if __name__ == "__main__":
198-
unittest.main()

ci/L0_additional_outputs_vllm/test.sh

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
export CUDA_VISIBLE_DEVICES=0
2929
source ../common/util.sh
3030

31+
pip3 install pytest==8.1.1
3132
pip3 install tritonclient[grpc]
3233

3334
# Prepare Model
@@ -38,8 +39,7 @@ sed -i 's/"gpu_memory_utilization": 0.5/"gpu_memory_utilization": 0.3/' models/v
3839

3940
RET=0
4041

41-
# Infer Test
42-
CLIENT_LOG="vllm_opt.log"
42+
# Test
4343
SERVER_LOG="vllm_opt.server.log"
4444
SERVER_ARGS="--model-repository=models"
4545
run_server
@@ -49,9 +49,8 @@ if [ "$SERVER_PID" == "0" ]; then
4949
exit 1
5050
fi
5151
set +e
52-
python3 additional_outputs_test.py > $CLIENT_LOG 2>&1
52+
python3 -m pytest -s -v additional_outputs_test.py
5353
if [ $? -ne 0 ]; then
54-
cat $CLIENT_LOG
5554
echo -e "\n***\n*** additional_outputs_test FAILED. \n***"
5655
RET=1
5756
fi

0 commit comments

Comments
 (0)