Skip to content

Commit 21245cf

Browse files
authored
fix: Improve validation for system shared memory register (#8336)
1 parent 50ff906 commit 21245cf

File tree

5 files changed

+83
-16
lines changed

5 files changed

+83
-16
lines changed

qa/L0_shared_memory/shared_memory_test.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -566,16 +566,45 @@ def test_register_reserved_names(self):
566566
"""
567567
# This matches kTritonSharedMemoryRegionPrefix in the server code.
568568
reserved_prefix = "triton_python_backend_shm_region_"
569+
shm_name = "my_test_shm_name"
570+
571+
# The shared memory key cannot start with the reserved prefix,
572+
# regardless of leading slashes.
573+
shm_keys_to_test = [
574+
f"{reserved_prefix}_my_test_shm_key",
575+
f"/{reserved_prefix}_my_test_shm_key",
576+
f"///{reserved_prefix}_my_test_shm_key",
577+
]
569578

570-
# The shared memory key cannot start with the reserved prefix.
579+
for shm_key in shm_keys_to_test:
580+
with self.subTest(shm_key=shm_key):
581+
expected_msg = f"cannot register shared memory region '{shm_name}' with key '{shm_key}' as the key contains the reserved prefix '{reserved_prefix}'"
582+
with self.assertRaisesRegex(
583+
utils.InferenceServerException, expected_msg
584+
):
585+
self.triton_client.register_system_shared_memory(
586+
shm_name, shm_key, 10000
587+
)
588+
589+
def test_register_invalid_shm_key(self):
590+
"""
591+
Test that registration fails if attempting to use an invalid name for the shm key.
592+
"""
571593
shm_name = "my_test_shm_name"
572-
shm_key = f"{reserved_prefix}_my_test_shm_key"
594+
shm_keys_to_test = [
595+
"/",
596+
"///",
597+
]
573598

574-
with self.assertRaisesRegex(
575-
utils.InferenceServerException,
576-
f"cannot register shared memory region '{shm_name}' with key '{shm_key}' as the key contains the reserved prefix '{reserved_prefix}'",
577-
) as e:
578-
self.triton_client.register_system_shared_memory(shm_name, shm_key, 10000)
599+
for shm_key in shm_keys_to_test:
600+
with self.subTest(shm_key=shm_key):
601+
expected_msg = f"cannot register shared memory region '{shm_name}' - invalid shm key '{shm_key}'"
602+
with self.assertRaisesRegex(
603+
utils.InferenceServerException, expected_msg
604+
):
605+
self.triton_client.register_system_shared_memory(
606+
shm_name, shm_key, 10000
607+
)
579608

580609

581610
def callback(user_data, result, error):

qa/L0_shared_memory/test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ for i in \
7070
test_infer_integer_overflow \
7171
test_register_out_of_bound \
7272
test_register_reserved_names \
73+
test_register_invalid_shm_key \
7374
test_python_client_leak; do
7475
for client_type in http grpc; do
7576
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1 ${SERVER_ARGS_EXTRA}"

src/common.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,39 @@ DecodeBase64(
163163
return nullptr;
164164
}
165165

166+
TRITONSERVER_Error*
167+
ValidateSharedMemoryKey(const std::string& name, const std::string& shm_key)
168+
{
169+
std::string_view key_view(shm_key);
170+
171+
// Find the index of the first character that is not a slash
172+
const std::size_t first_non_slash = key_view.find_first_not_of('/');
173+
174+
// If the entire key is slashes
175+
if (first_non_slash == std::string_view::npos) {
176+
return TRITONSERVER_ErrorNew(
177+
TRITONSERVER_ERROR_INVALID_ARG,
178+
std::string(
179+
"cannot register shared memory region '" + name +
180+
"' - invalid shm key '" + shm_key + "'")
181+
.c_str());
182+
}
183+
184+
// Check whether the substring starting at first_non_slash starts with the
185+
// reserved prefix
186+
if (key_view.substr(first_non_slash)
187+
.rfind(kTritonSharedMemoryRegionPrefix, 0) == 0) {
188+
return TRITONSERVER_ErrorNew(
189+
TRITONSERVER_ERROR_INVALID_ARG,
190+
std::string(
191+
"cannot register shared memory region '" + name + "' with key '" +
192+
shm_key + "' as the key contains the reserved prefix '" +
193+
kTritonSharedMemoryRegionPrefix + "'")
194+
.c_str());
195+
}
196+
197+
// Valid shm key
198+
return nullptr;
199+
}
200+
166201
}} // namespace triton::server

src/common.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,16 @@ TRITONSERVER_Error* DecodeBase64(
196196
const char* input, size_t input_len, std::vector<char>& decoded_data,
197197
size_t& decoded_size, const std::string& name);
198198

199+
200+
/// Validate shared memory key
201+
///
202+
/// \param name The name of the memory block.
203+
/// \param shm_key The name of the posix shared memory object
204+
/// \return The error status.
205+
TRITONSERVER_Error* ValidateSharedMemoryKey(
206+
const std::string& name, const std::string& shm_key);
207+
208+
199209
/// Joins container of strings into a single string delimited by
200210
/// 'delim'.
201211
///

src/shared_memory_manager.cc

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -355,15 +355,7 @@ SharedMemoryManager::RegisterSystemSharedMemory(
355355
const size_t byte_size)
356356
{
357357
// Check if the shared memory key starts with the reserved prefix
358-
if (shm_key.rfind(kTritonSharedMemoryRegionPrefix, 0) == 0) {
359-
return TRITONSERVER_ErrorNew(
360-
TRITONSERVER_ERROR_INVALID_ARG,
361-
std::string(
362-
"cannot register shared memory region '" + name + "' with key '" +
363-
shm_key + "' as the key contains the reserved prefix '" +
364-
kTritonSharedMemoryRegionPrefix + "'")
365-
.c_str());
366-
}
358+
RETURN_IF_ERR(ValidateSharedMemoryKey(name, shm_key));
367359

368360
std::lock_guard<std::mutex> lock(mu_);
369361

0 commit comments

Comments
 (0)