Skip to content

Commit 69d768d

Browse files
authored
test: Add BF16 test for python backend (#7483)
1 parent fb056b1 commit 69d768d

File tree

4 files changed

+172
-0
lines changed

4 files changed

+172
-0
lines changed

qa/L0_backend_python/python_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,36 @@ def test_bool(self):
365365
self.assertIsNotNone(output0)
366366
self.assertTrue(np.all(output0 == input_data))
367367

368+
def test_bf16(self):
    """Send an all-ones tensor to ``identity_bf16`` and verify the round trip.

    The request input is declared as BF16 while the payload is supplied as
    an FP32 numpy array. The test checks that the response reports a BF16
    output, that the numpy view of that output is FP32, and that its
    values match what was sent.
    """
    model_name = "identity_bf16"
    shape = [2, 2]
    with self._shm_leak_detector.Probe() as shm_probe:
        with httpclient.InferenceServerClient(
            f"{_tritonserver_ipaddr}:8000"
        ) as client:
            # NOTE: numpy has no built-in BF16 representation, so the data
            # is handed over as FP32 and the client truncates it to BF16.
            expected = np.ones(shape, dtype=np.float32)
            bf16_input = httpclient.InferInput(
                "INPUT0", expected.shape, "BF16"
            ).set_data_from_numpy(expected)
            result = client.infer(model_name, [bf16_input])

            # The datatype reported in the response must still be BF16.
            reported_dtype = result.get_response()["outputs"][0]["datatype"]
            self.assertEqual(reported_dtype, "BF16")

            actual = result.as_numpy("OUTPUT0")
            self.assertIsNotNone(actual)
            # Without native BF16 support in numpy, the converted output
            # is held in FP32; make sure that is what comes back.
            self.assertEqual(actual.dtype, np.float32)
            self.assertTrue(np.allclose(actual, expected))
397+
368398
def test_infer_pytorch(self):
369399
# FIXME: This model requires torch. Because windows tests are not run in a docker
370400
# environment with torch installed, we need to think about how we want to install

qa/L0_backend_python/test.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ fi
9595
mkdir -p models/identity_fp32/1/
9696
cp ../python_models/identity_fp32/model.py ./models/identity_fp32/1/model.py
9797
cp ../python_models/identity_fp32/config.pbtxt ./models/identity_fp32/config.pbtxt
98+
mkdir -p models/identity_bf16/1/
99+
cp ../python_models/identity_bf16/model.py ./models/identity_bf16/1/model.py
100+
cp ../python_models/identity_bf16/config.pbtxt ./models/identity_bf16/config.pbtxt
98101
RET=0
99102

100103
cp -r ./models/identity_fp32 ./models/identity_uint8
qa/python_models/identity_bf16/config.pbtxt

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
# Model configuration for the BF16 identity model served by the Python backend.
backend: "python"
max_batch_size: 64

# INPUT0 and OUTPUT0 are declared with the same datatype (TYPE_BF16) and the
# same dims, matching a model that returns its input unchanged.
input [
  {
    name: "INPUT0"
    data_type: TYPE_BF16
    dims: [ -1 ]
  }
]

output [
  {
    name: "OUTPUT0"
    data_type: TYPE_BF16
    dims: [ -1 ]
  }
]

# A single model instance, placed on CPU.
instance_group [
  {
    count: 1
    kind : KIND_CPU
  }
]
qa/python_models/identity_bf16/model.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
import json
28+
29+
import torch
30+
import triton_python_backend_utils as pb_utils
31+
32+
33+
class TritonPythonModel:
    """Identity model for the Python backend demonstrating BF16 handling.

    Each request's ``INPUT0`` tensor is passed through DLPack into PyTorch
    and back, then returned as ``OUTPUT0``. Both tensors are additionally
    checked with :meth:`validate_bf16_tensor`, which exists for testing and
    example purposes only.
    """

    def initialize(self, args):
        """Parse the model config and cache the INPUT0/OUTPUT0 tensor configs.

        Args:
            args: Mapping provided by the backend. Only
                ``args["model_config"]`` (a JSON string) is read here.
        """
        # You must parse model_config. JSON string is not parsed here
        self.model_config = json.loads(args["model_config"])

        # Get tensor configurations for testing/validation
        self.input0_config = pb_utils.get_input_config_by_name(
            self.model_config, "INPUT0"
        )
        self.output0_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT0"
        )

    def validate_bf16_tensor(self, tensor, tensor_config):
        """Check that a tensor is configured as BF16 and rejects numpy conversion.

        Args:
            tensor: Object exposing ``as_numpy()``; the call is expected to
                raise ``pb_utils.TritonModelException`` for a BF16 tensor.
            tensor_config: Config dict for the tensor; its ``"data_type"``
                entry must equal ``"TYPE_BF16"``.

        Raises:
            Exception: If the configured datatype is not ``TYPE_BF16``, or
                if ``as_numpy()`` unexpectedly succeeds.
            AssertionError: If ``as_numpy()`` fails with a
                ``TritonModelException`` whose message is not the expected
                BF16 conversion error.
        """
        # I/O datatypes can be queried from the model config if needed
        dtype = tensor_config["data_type"]
        if dtype != "TYPE_BF16":
            raise Exception(f"Expected a BF16 tensor, but got {dtype} instead.")

        # Converting BF16 tensors to numpy is not supported, and DLPack
        # should be used instead via to_dlpack and from_dlpack.
        try:
            tensor.as_numpy()
        except pb_utils.TritonModelException as e:
            expected_error = "tensor dtype is bf16 and cannot be converted to numpy"
            # Raise explicitly instead of using a bare `assert`: assert
            # statements are removed under `python -O`, which would let an
            # unrelated error message pass this check silently.
            if expected_error not in str(e).lower():
                raise AssertionError(
                    f"Expected error containing '{expected_error}', but got: {e}"
                ) from e
        else:
            raise Exception("Expected BF16 conversion to numpy to fail")

    def execute(self, requests):
        """
        Identity model in Python backend with example BF16 and PyTorch usage.

        Args:
            requests: Iterable of inference requests, each carrying an
                ``INPUT0`` tensor.

        Returns:
            A list with one ``pb_utils.InferenceResponse`` per request, each
            holding a single ``OUTPUT0`` tensor built from that request's
            input.
        """
        responses = []
        for request in requests:
            input_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0")

            # Numpy does not support BF16, so use DLPack instead.
            bf16_dlpack = input_tensor.to_dlpack()

            # OPTIONAL: The tensor can be converted to other dlpack-compatible
            # frameworks like PyTorch and TensorFlow with their dlpack utilities.
            torch_tensor = torch.utils.dlpack.from_dlpack(bf16_dlpack)

            # When complete, convert back to a pb_utils.Tensor via DLPack.
            output_tensor = pb_utils.Tensor.from_dlpack(
                "OUTPUT0", torch.utils.dlpack.to_dlpack(torch_tensor)
            )
            responses.append(pb_utils.InferenceResponse([output_tensor]))

            # NOTE: The following helper function is for testing and example
            # purposes only, you should remove this in practice.
            self.validate_bf16_tensor(input_tensor, self.input0_config)
            self.validate_bf16_tensor(output_tensor, self.output0_config)

        return responses

0 commit comments

Comments
 (0)