Skip to content

Commit 595a488

Browse files
committed
Only check CPU shm to pass unit tests
1 parent 2f2c207 commit 595a488

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

src/pb_memory.cc

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
#include "pb_memory.h"
2828

29+
#include <sstream>
30+
2931
namespace triton { namespace backend { namespace python {
3032

3133
std::unique_ptr<PbMemory>
@@ -225,12 +227,6 @@ PbMemory::LoadFromSharedMemory(
225227
{
226228
MemoryShm* memory_shm_ptr = reinterpret_cast<MemoryShm*>(data_shm);
227229
char* memory_data_shm = data_shm + sizeof(MemoryShm);
228-
229-
if (memory_data_shm + memory_shm_ptr->byte_size >
230-
(char*)shm_pool->GetBaseAddress() + shm_pool->GetCurrentCapacity()) {
231-
throw PythonBackendException("Attempted to access out of bounds memory.");
232-
}
233-
234230
char* data_ptr = nullptr;
235231
bool opened_cuda_ipc_handle = false;
236232
if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU &&
@@ -265,6 +261,19 @@ PbMemory::LoadFromSharedMemory(
265261
} else {
266262
data_ptr = memory_data_shm;
267263
}
264+
265+
// This check only validates CPU shared memory access.
266+
if (memory_shm_ptr->memory_type != TRITONSERVER_MEMORY_GPU &&
267+
(data_ptr + memory_shm_ptr->byte_size >
268+
(char*)shm_pool->GetBaseAddress() + shm_pool->GetCurrentCapacity())) {
269+
std::ostringstream oss;
270+
oss << "0x" << std::hex
271+
<< (reinterpret_cast<uintptr_t>(data_ptr) + memory_shm_ptr->byte_size);
272+
throw PythonBackendException(
273+
std::string("Attempted to access out of bounds memory address ") +
274+
oss.str());
275+
}
276+
268277
return std::unique_ptr<PbMemory>(new PbMemory(
269278
data_shm, data_ptr, handle,
270279
opened_cuda_ipc_handle /* opened_cuda_ipc_handle */));
@@ -280,11 +289,6 @@ PbMemory::LoadFromSharedMemory(
280289
reinterpret_cast<MemoryShm*>(memory_shm.data_.get());
281290
char* memory_data_shm = memory_shm.data_.get() + sizeof(MemoryShm);
282291

283-
if (memory_data_shm + memory_shm_ptr->byte_size >
284-
(char*)shm_pool->GetBaseAddress() + shm_pool->GetCurrentCapacity()) {
285-
throw PythonBackendException("Attempted to access out of bounds memory.");
286-
}
287-
288292
char* data_ptr = nullptr;
289293
bool opened_cuda_ipc_handle = false;
290294
if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU) {
@@ -319,6 +323,19 @@ PbMemory::LoadFromSharedMemory(
319323
} else {
320324
data_ptr = memory_data_shm;
321325
}
326+
327+
// This check only validates CPU shared memory access.
328+
if (memory_shm_ptr->memory_type != TRITONSERVER_MEMORY_GPU &&
329+
(data_ptr + memory_shm_ptr->byte_size >
330+
(char*)shm_pool->GetBaseAddress() + shm_pool->GetCurrentCapacity())) {
331+
std::ostringstream oss;
332+
oss << "0x" << std::hex
333+
<< (reinterpret_cast<uintptr_t>(data_ptr) + memory_shm_ptr->byte_size);
334+
throw PythonBackendException(
335+
std::string("Attempted to access out of bounds memory address ") +
336+
oss.str());
337+
}
338+
322339
return std::unique_ptr<PbMemory>(new PbMemory(
323340
memory_shm, data_ptr,
324341
opened_cuda_ipc_handle /* opened_cuda_ipc_handle */));

0 commit comments

Comments
 (0)