 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
+
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -53,6 +55,9 @@
 #include <cuda_runtime_api.h>
 #endif  // TRITON_ENABLE_GPU
 
+// Default forward method to call on PyTorch modules
+const std::string DEFAULT_MODULE_METHOD_NAME = "forward";
+
 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
 //
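The constant above only fixes the fallback name; the method that is actually invoked becomes configurable further down. As a side note (a hedged sketch, not code from this change), a standalone program can probe a TorchScript module for an optional method and fall back to "forward" when it is missing. The file name "model.pt" and the method name "encode" are illustrative assumptions:

// Sketch only: check whether a TorchScript module exposes a given method.
// "model.pt" and "encode" are hypothetical placeholders, not backend values.
#include <torch/script.h>

#include <iostream>
#include <string>

int main() {
  torch::jit::Module module = torch::jit::load("model.pt");
  const std::string requested = "encode";
  // find_method() returns an empty c10::optional when the method is absent.
  auto found = module.find_method(requested);
  const std::string method_name = found ? requested : std::string("forward");
  std::cout << "will call: " << method_name << std::endl;
  return 0;
}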
@@ -103,6 +108,7 @@ class ModelState : public BackendModel {
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
   const std::vector<std::string>& ModelOutputs() { return output_names_; }
+  const std::string& ModuleMethodName() { return module_method_name_; }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -145,6 +151,10 @@ class ModelState : public BackendModel {
   // List of all the outputs specified in the output section of model
   // configuration.
   std::vector<std::string> output_names_;
+
+  // Method to call on PyTorch Module.
+  // Defaults to DEFAULT_MODULE_METHOD_NAME.
+  std::string module_method_name_;
 };
 
 TRITONSERVER_Error*
@@ -180,7 +190,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
-      enable_nvfuser_pair_({false, false})
+      enable_nvfuser_pair_({false, false}),
+      module_method_name_(DEFAULT_MODULE_METHOD_NAME)
 {
   output_names_.clear();
 
@@ -454,6 +465,30 @@ ModelState::ParseParameters()
454465 " for model instance '" + Name () + " '" )
455466 .c_str ());
456467 }
468+
469+ // If 'MODULE_METHOD_NAME' is not present in 'parameters' then
470+ // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
471+ std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
472+ err = GetParameterValue (params, " MODULE_METHOD_NAME" , &module_method_name);
473+ if (err != nullptr ) {
474+ if (TRITONSERVER_ErrorCode (err) != TRITONSERVER_ERROR_NOT_FOUND) {
475+ return err;
476+ } else {
477+ LOG_MESSAGE (
478+ TRITONSERVER_LOG_INFO,
479+ (std::string (" module_method_name is not specified" ) +
480+ " for model instance '" + Name () + " '" )
481+ .c_str ());
482+ TRITONSERVER_ErrorDelete (err);
483+ }
484+ } else {
485+ module_method_name_ = module_method_name;
486+ LOG_MESSAGE (
487+ TRITONSERVER_LOG_INFO,
488+ (std::string (" module_method_name is " ) + module_method_name_ +
489+ " for model instance '" + Name () + " '" )
490+ .c_str ());
491+ }
457492 }
458493
459494 return nullptr ;
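For context, a hedged usage sketch rather than part of this diff: the parameter parsed above is supplied through the parameters section of the model's config.pbtxt. The method name "custom_forward" below is an illustrative placeholder:

parameters: {
  key: "MODULE_METHOD_NAME"
  value: {
    string_value: "custom_forward"
  }
}

When the parameter is omitted, the code above logs that fact and keeps the default "forward".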
@@ -764,7 +799,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;
 
-  const torch::jit::Method& method = torch_model_->get_method("forward");
+  const torch::jit::Method& method =
+      torch_model_->get_method(model_state_->ModuleMethodName());
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
 
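The lookup above retrieves the configured method and reads its schema so the declared inputs can be checked against the model configuration. A minimal standalone sketch of the same libtorch calls, assuming an illustrative "model.pt" and the default "forward" method:

// Sketch: print the argument names of a named method on a TorchScript
// module, mirroring how ValidateInputs() walks the schema.
#include <torch/script.h>

#include <iostream>

int main() {
  torch::jit::Module module = torch::jit::load("model.pt");
  const torch::jit::Method& method = module.get_method("forward");
  const auto& schema = method.function().getSchema();
  // The first argument is the implicit 'self'; real model inputs follow it.
  for (const c10::Argument& arg : schema.arguments()) {
    std::cout << arg.name() << std::endl;
  }
  return 0;
}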
@@ -1312,30 +1348,36 @@ ModelInstanceState::Execute(
       torch::jit::overrideCanFuseOnCPU(false);
       torch::jit::overrideCanFuseOnGPU(false);
       torch::jit::setTensorExprFuserEnabled(false);
-        torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
     } else {
       torch::jit::overrideCanFuseOnCPU(true);
       torch::jit::overrideCanFuseOnGPU(true);
       torch::jit::setTensorExprFuserEnabled(true);
-        torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
     }
   }
 
   torch::NoGradGuard no_grad;
 
   // If input is a dictionary, prepare dictionary from 'input_tensors'.
+  std::string module_method_name = model_state_->ModuleMethodName();
+  std::vector<c10::IValue> inputs;
   if (is_dict_input_) {
-    torch::Dict<std::string, torch::Tensor> input_dict;
+    c10::Dict<std::string, at::Tensor> dict;
     for (auto& input_index : input_index_map_) {
       torch::jit::IValue ival = (*input_tensors)[input_index.second];
-      input_dict.insert(input_index.first, ival.toTensor());
+      dict.insert(input_index.first, ival.toTensor());
     }
-    std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
-    model_outputs_ = torch_model_->forward(input_dict_ivalue);
+    inputs.push_back(dict);
   } else {
-    model_outputs_ = torch_model_->forward(*input_tensors);
+    for (auto& input_tensor : *input_tensors) {
+      inputs.push_back(input_tensor.toTensor());
+    }
   }
 
+  // Actually run the method on the model.
+  model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
+
   if (model_outputs_.isTuple()) {
     auto model_outputs_tuple = model_outputs_.toTuple();
     size_t op_index = 0;
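With this change the inputs are always packed into a std::vector<c10::IValue> (a single dict for dictionary-input models, one tensor per input otherwise), and the configured method is looked up and called instead of hard-coding forward(). A minimal sketch of the same call pattern, assuming a hypothetical "predict" method and a dummy input tensor:

// Sketch: invoke an arbitrary exported method on a TorchScript module via
// get_method(), as the rewritten Execute() path does. "model.pt" and
// "predict" are assumptions for illustration.
#include <torch/script.h>

#include <vector>

int main() {
  torch::jit::Module module = torch::jit::load("model.pt");

  std::vector<c10::IValue> inputs;
  inputs.push_back(torch::ones({1, 3}));  // stand-in for a collected input

  // get_method() resolves the method by name; calling it runs the graph.
  c10::IValue outputs = module.get_method("predict")(inputs);

  // Models with several outputs return a tuple, which matches the
  // isTuple() handling that follows in the backend.
  if (outputs.isTuple()) {
    auto output_elements = outputs.toTuple()->elements();
  }
  return 0;
}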
@@ -1761,9 +1803,9 @@ ModelInstanceState::SetInputTensors(
 
         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
@@ -1772,8 +1814,8 @@ ModelInstanceState::SetInputTensors(
     // The input must be in contiguous CPU/GPU memory.
     std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
     if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                          {TRITONSERVER_MEMORY_CPU, 0}};
+      alloc_perference = {
+          {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
     } else {
       alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
     }
@@ -1887,9 +1929,11 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
 
       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1950,16 @@ ModelInstanceState::ReadOutputTensors(
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               (std::string("output '") + name +
-                  "' is a scalar which is not supported.")
+                "' is a scalar which is not supported.")
                   .c_str());
         }
 
         responder.ProcessTensor(
-            name, output_dtype, batchn_shape, output_buffer,
-            memory_type, memory_id);
+            name, output_dtype, batchn_shape, output_buffer, memory_type,
+            memory_id);
       } else {
         responder.ProcessBatchOutput(
-          name, *batch_output, output_buffer, memory_type, memory_id);
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...