|
1 | | -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. |
| 1 | +// Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved. |
2 | 2 | // |
3 | 3 | // Redistribution and use in source and binary forms, with or without |
4 | 4 | // modification, are permitted provided that the following conditions |
@@ -847,9 +847,15 @@ TRITONBACKEND_ModelInstanceExecute( |
847 | 847 | if (input_memory_type == TRITONSERVER_MEMORY_GPU) { |
848 | 848 | ipbuffer_vec.resize(input_element_cnt); |
849 | 849 | ipbuffer_int = ipbuffer_vec.data(); |
850 | | - cudaMemcpy( |
851 | | - const_cast<int32_t*>(ipbuffer_int), input_buffer, input_byte_size, |
852 | | - cudaMemcpyDeviceToHost); |
| 850 | + LOG_IF_CUDA_ERROR( |
| 851 | + cudaMemcpyAsync( |
| 852 | + const_cast<int32_t*>(ipbuffer_int), input_buffer, input_byte_size, |
| 853 | + cudaMemcpyDeviceToHost, instance_state->CudaStream()), |
| 854 | + "failed to copy buffer from Device to Host"); |
| 855 | + |
| 856 | + LOG_IF_CUDA_ERROR( |
| 857 | + cudaStreamSynchronize(instance_state->CudaStream()), |
| 858 | + "failed to perform synchronization on cuda stream"); |
853 | 859 | } else { |
854 | 860 | ipbuffer_int = reinterpret_cast<const int32_t*>(input_buffer); |
855 | 861 | } |
@@ -939,9 +945,15 @@ TRITONBACKEND_ModelInstanceExecute( |
939 | 945 | } |
940 | 946 |
|
941 | 947 | if (output_memory_type == TRITONSERVER_MEMORY_GPU) { |
942 | | - cudaMemcpy( |
943 | | - output_buffer, const_cast<int32_t*>(obuffer_int), |
944 | | - buffer_byte_size, cudaMemcpyHostToDevice); |
| 948 | + LOG_IF_CUDA_ERROR( |
| 949 | + cudaMemcpyAsync( |
| 950 | + output_buffer, const_cast<int32_t*>(obuffer_int), |
| 951 | + buffer_byte_size, cudaMemcpyHostToDevice, |
| 952 | + instance_state->CudaStream()), |
| 953 | + "failed to copy buffer from Host to Device"); |
| 954 | + LOG_IF_CUDA_ERROR( |
| 955 | + cudaStreamSynchronize(instance_state->CudaStream()), |
| 956 | + "failed to perform synchronization on cuda stream"); |
945 | 957 | } |
946 | 958 | } |
947 | 959 | } |
|
0 commit comments