Skip to content

Commit c077c86

Browse files
authored
Add check for sequence data type (#93)
1 parent 2559db9 commit c077c86

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

src/libtorch.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -741,6 +741,15 @@ ModelInstanceState::ValidateTypedSequenceControl(
741741
}
742742
}
743743

744+
// check if the data type is supported by PyTorch
745+
if (!ModelConfigDataTypeToTorchType(tensor_datatype).first) {
746+
return TRITONSERVER_ErrorNew(
747+
TRITONSERVER_ERROR_INTERNAL,
748+
("input '" + tensor_name + "' type '" + tensor_datatype +
749+
"' is not supported by PyTorch.")
750+
.c_str());
751+
}
752+
744753
ip_index = std::atoi(tensor_name.substr(start_pos + 2).c_str());
745754
input_index_map_[tensor_name] = ip_index;
746755
}

0 commit comments

Comments
 (0)