@@ -168,6 +168,8 @@ triton_to_pybind_dtype(TRITONSERVER_DataType data_type)
168
168
dtype_numpy = py::dtype (py::format_descriptor<uint8_t >::format ());
169
169
break ;
170
170
case TRITONSERVER_TYPE_BF16:
171
+ // Currently skipping this call via `if (BF16)` check, but probably
172
+ // need to handle this or set some default/invalid dtype.
171
173
throw PythonBackendException (" TYPE_BF16 not currently supported." );
172
174
case TRITONSERVER_TYPE_INVALID:
173
175
throw PythonBackendException (" Dtype is invalid." );
@@ -240,6 +242,10 @@ triton_to_dlpack_type(TRITONSERVER_DataType triton_dtype)
240
242
case TRITONSERVER_TYPE_BYTES:
241
243
throw PythonBackendException (
242
244
" TYPE_BYTES tensors cannot be converted to DLPack." );
245
+ case TRITONSERVER_TYPE_BF16:
246
+ dl_code = DLDataTypeCode::kDLBfloat ;
247
+ dt_size = 16 ;
248
+ break ;
243
249
244
250
default :
245
251
throw PythonBackendException (
@@ -301,6 +307,15 @@ dlpack_to_triton_type(const DLDataType& data_type)
301
307
}
302
308
}
303
309
310
+ if (data_type.code == DLDataTypeCode::kDLBfloat ) {
311
+ if (data_type.bits != 16 ) {
312
+ throw PythonBackendException (
313
+ " Expected BF16 tensor to have 16 bits, but had: " +
314
+ std::to_string (data_type.bits ));
315
+ }
316
+ return TRITONSERVER_TYPE_BF16;
317
+ }
318
+
304
319
return TRITONSERVER_TYPE_INVALID;
305
320
}
306
321
}}} // namespace triton::backend::python
0 commit comments