Skip to content

Commit f1aedd4

Browse files
authored
Add timed wait during UNLOAD while the model becomes UNAVAILABLE in SageMaker (#5423)
* Add timed wait during UNLOAD while the model becomes UNAVAILABLE in SageMaker * Directly use C API to UNLOAD model in SM * Address comments and bug fixes * Add logging for model server index * Change MME model repo * Address comments and use chrono seconds, don't repeat error assignment * Address minor comments * Fix typo in log * Update minor comment
1 parent b601bd8 commit f1aedd4

File tree

3 files changed

+167
-34
lines changed

3 files changed

+167
-34
lines changed

docker/sagemaker/serve

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,21 @@
2626
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2727

2828
SAGEMAKER_SINGLE_MODEL_REPO=/opt/ml/model/
29-
SAGEMAKER_MULTI_MODEL_REPO=/opt/ml/models/
29+
30+
# Note: in Triton on SageMaker, each model url is registered as a separate repository
31+
# e.g., /opt/ml/models/<hash>/model. Specifying MME model repo path as /opt/ml/models causes Triton
32+
# to treat it as an additional empty repository and changes
33+
# the state of all models to be UNAVAILABLE in the model repository
34+
# https://github.com/triton-inference-server/core/blob/main/src/model_repository_manager.cc#L914,L922
35+
# On Triton, this path will be a dummy path as it's mandatory to specify a model repo when starting triton
36+
SAGEMAKER_MULTI_MODEL_REPO=/tmp/sagemaker
3037

3138
SAGEMAKER_MODEL_REPO=${SAGEMAKER_SINGLE_MODEL_REPO}
3239
is_mme_mode=false
3340

3441
if [ -n "$SAGEMAKER_MULTI_MODEL" ]; then
3542
if [ "$SAGEMAKER_MULTI_MODEL" == "true" ]; then
43+
mkdir -p ${SAGEMAKER_MULTI_MODEL_REPO}
3644
SAGEMAKER_MODEL_REPO=${SAGEMAKER_MULTI_MODEL_REPO}
3745
is_mme_mode=true
3846
echo "Triton is running in SageMaker MME mode."

src/sagemaker_server.cc

Lines changed: 146 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -628,69 +628,185 @@ SagemakerAPIServer::SageMakerMMEHandleInfer(
628628
}
629629
}
630630

631+
TRITONSERVER_Error*
632+
SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
633+
const char* model_name, bool* is_model_unavailable)
634+
{
635+
/* Use the RepositoryIndex API to check if the model state has become
636+
UNAVAILABLE i.e. model is no longer in the 'in-the-process-of' being
637+
UNLOADED. Consequently, the reason field should be 'unloaded'.*/
638+
TRITONSERVER_Message* server_model_index_message = nullptr;
639+
uint32_t ready_flag = 0; // value of 1 should be set if only the 'ready'
640+
// models are required from the index. In this case,
641+
// we need all models.
642+
TRITONSERVER_ServerModelIndex(
643+
server_.get(), ready_flag, &server_model_index_message);
644+
645+
std::shared_ptr<TRITONSERVER_Message> shared_ptr_msg(
646+
server_model_index_message,
647+
[](TRITONSERVER_Message* msg) { TRITONSERVER_MessageDelete(msg); });
648+
649+
const char* index_buffer;
650+
size_t index_byte_size;
651+
652+
RETURN_IF_ERR(TRITONSERVER_MessageSerializeToJson(
653+
server_model_index_message, &index_buffer, &index_byte_size));
654+
655+
/* Read into json buffer*/
656+
triton::common::TritonJson::Value server_model_index_json;
657+
server_model_index_json.Parse(index_buffer, index_byte_size);
658+
659+
const char* name;
660+
const char* state;
661+
const char* reason;
662+
const char* version;
663+
664+
size_t name_len;
665+
size_t state_len;
666+
size_t reason_len;
667+
size_t version_len;
668+
669+
for (size_t id = 0; id < server_model_index_json.ArraySize(); ++id) {
670+
triton::common::TritonJson::Value index_json;
671+
server_model_index_json.IndexAsObject(id, &index_json);
672+
673+
RETURN_IF_ERR(index_json.MemberAsString("name", &name, &name_len));
674+
675+
if (std::string(name) == std::string(model_name)) {
676+
RETURN_IF_ERR(index_json.MemberAsString("state", &state, &state_len));
677+
678+
if (std::string(state) == UNLOAD_EXPECTED_STATE_) {
679+
RETURN_IF_ERR(
680+
index_json.MemberAsString("reason", &reason, &reason_len));
681+
682+
if (std::string(reason) == UNLOAD_EXPECTED_REASON_) {
683+
*is_model_unavailable = true;
684+
685+
RETURN_IF_ERR(
686+
index_json.MemberAsString("version", &version, &version_len));
687+
688+
LOG_VERBOSE(1) << "Discovered model: " << name
689+
<< ", version: " << version << " in state: " << state
690+
<< "for the reason: " << reason;
691+
692+
break;
693+
}
694+
}
695+
}
696+
}
697+
698+
return nullptr;
699+
}
700+
631701
void
632702
SagemakerAPIServer::SageMakerMMEUnloadModel(
633703
evhtp_request_t* req, const char* model_name)
634704
{
635-
std::lock_guard<std::mutex> lock(mutex_);
636-
637705
if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) {
638706
LOG_VERBOSE(1) << "Model " << model_name << " is not loaded." << std::endl;
639707
evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/
640708
return;
641709
}
642710

711+
/* Extract targetModel to log the associated archive */
712+
const char* targetModel =
713+
evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");
714+
715+
LOG_INFO << "Unloading SageMaker TargetModel: " << targetModel << std::endl;
716+
717+
auto start_time = std::chrono::high_resolution_clock::now();
718+
643719
/* Always unload dependents as well - this is required to unload dependents in
644720
* ensemble */
645-
triton::common::TritonJson::Value request_parameters(
646-
triton::common::TritonJson::ValueType::OBJECT);
647-
triton::common::TritonJson::Value unload_parameter(
648-
request_parameters, triton::common::TritonJson::ValueType::OBJECT);
649-
650-
unload_parameter.AddBool("unload_dependents", true);
651-
request_parameters.Add("parameters", std::move(unload_parameter));
721+
TRITONSERVER_Error* unload_err = nullptr;
722+
unload_err =
723+
TRITONSERVER_ServerUnloadModelAndDependents(server_.get(), model_name);
652724

653-
const char* buffer;
654-
size_t byte_size;
725+
if (unload_err != nullptr) {
726+
EVBufferAddErrorJson(req->buffer_out, unload_err);
727+
evhtp_send_reply(req, EVHTP_RES_BADREQ);
655728

656-
triton::common::TritonJson::WriteBuffer json_buffer_;
657-
json_buffer_.Clear();
658-
request_parameters.Write(&json_buffer_);
729+
LOG_ERROR
730+
<< "Error when unloading SageMaker Model with dependents for model: "
731+
<< model_name << std::endl;
659732

660-
byte_size = json_buffer_.Size();
661-
buffer = json_buffer_.Base();
733+
TRITONSERVER_ErrorDelete(unload_err);
734+
return;
735+
}
662736

663-
evbuffer_add(req->buffer_in, buffer, byte_size);
737+
/*Note: Model status check is repo-specific and therefore must be run before
738+
* unregistering the repo, else the model information is lost*/
739+
bool is_model_unavailable = false;
740+
int64_t unload_time_in_secs = 0;
741+
742+
/* Wait for the model to be completely unloaded. SageMaker waits a maximum
743+
of 360 seconds for the UNLOAD request to timeout. Setting a limit of 350
744+
seconds for Triton unload. This should be run only if above UNLOAD call has
745+
succeeded.*/
746+
if (unload_err == nullptr) {
747+
LOG_VERBOSE(1) << "Using Model Repository Index during UNLOAD to check for "
748+
"status of model: "
749+
<< model_name;
750+
while (is_model_unavailable == false &&
751+
unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
752+
LOG_VERBOSE(1) << "In the loop to wait for model to be unavailable";
753+
unload_err = SageMakerMMECheckUnloadedModelIsUnavailable(
754+
model_name, &is_model_unavailable);
755+
if (unload_err != nullptr) {
756+
LOG_ERROR << "Error: Received non-zero exit code on checking for "
757+
"model unavailability. "
758+
<< TRITONSERVER_ErrorMessage(unload_err);
759+
break;
760+
}
761+
std::this_thread::sleep_for(
762+
std::chrono::milliseconds(UNLOAD_SLEEP_MILLISECONDS_));
664763

665-
/* Extract targetModel to log the associated archive */
666-
const char* targetModel =
667-
evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");
764+
auto end_time = std::chrono::high_resolution_clock::now();
668765

669-
LOG_INFO << "Unloading SageMaker TargetModel: " << targetModel << std::endl;
766+
unload_time_in_secs = std::chrono::duration_cast<std::chrono::seconds>(
767+
end_time - start_time)
768+
.count();
769+
}
770+
LOG_INFO << "UNLOAD for model " << model_name << " completed in "
771+
<< unload_time_in_secs << " seconds.";
772+
TRITONSERVER_ErrorDelete(unload_err);
773+
}
670774

671-
HandleRepositoryControl(req, "", model_name, "unload");
775+
if ((is_model_unavailable == false) &&
776+
(unload_time_in_secs >= UNLOAD_TIMEOUT_SECS_)) {
777+
LOG_ERROR << "Error: UNLOAD did not complete within expected "
778+
<< UNLOAD_TIMEOUT_SECS_
779+
<< " seconds. This may "
780+
"result in SageMaker UNLOAD timeout.";
781+
}
672782

673783
std::string repo_parent_path = sagemaker_models_list_.at(model_name);
674784

675-
TRITONSERVER_Error* unload_err = TRITONSERVER_ServerUnregisterModelRepository(
785+
TRITONSERVER_Error* unregister_err = nullptr;
786+
787+
unregister_err = TRITONSERVER_ServerUnregisterModelRepository(
676788
server_.get(), repo_parent_path.c_str());
677789

678-
if (unload_err != nullptr) {
790+
if (unregister_err != nullptr) {
679791
EVBufferAddErrorJson(req->buffer_out, unload_err);
680792
evhtp_send_reply(req, EVHTP_RES_BADREQ);
681793
LOG_ERROR << "Unable to unregister model repository for path: "
682794
<< repo_parent_path << std::endl;
683-
TRITONSERVER_ErrorDelete(unload_err);
795+
} else {
796+
evhtp_send_reply(req, EVHTP_RES_OK);
684797
}
685798

799+
TRITONSERVER_ErrorDelete(unregister_err);
800+
801+
std::lock_guard<std::mutex> lock(models_list_mutex_);
686802
sagemaker_models_list_.erase(model_name);
687803
}
688804

689805
void
690806
SagemakerAPIServer::SageMakerMMEGetModel(
691807
evhtp_request_t* req, const char* model_name)
692808
{
693-
std::lock_guard<std::mutex> lock(mutex_);
809+
std::lock_guard<std::mutex> lock(models_list_mutex_);
694810

695811
if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) {
696812
evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/
@@ -721,7 +837,7 @@ SagemakerAPIServer::SageMakerMMEGetModel(
721837
void
722838
SagemakerAPIServer::SageMakerMMEListModel(evhtp_request_t* req)
723839
{
724-
std::lock_guard<std::mutex> lock(mutex_);
840+
std::lock_guard<std::mutex> lock(models_list_mutex_);
725841

726842
triton::common::TritonJson::Value sagemaker_list_json(
727843
triton::common::TritonJson::ValueType::OBJECT);
@@ -866,8 +982,8 @@ SagemakerAPIServer::SageMakerMMELoadModel(
866982
if (config_fstream.is_open()) {
867983
ensemble_config_content << config_fstream.rdbuf();
868984
} else {
869-
continue; // A valid config.pbtxt does not exit at this path, or cannot
870-
// be read
985+
continue; // A valid config.pbtxt does not exist at this path, or
986+
// cannot be read
871987
}
872988

873989
/* Compare matched string with `platform: "ensemble"` or
@@ -972,7 +1088,7 @@ SagemakerAPIServer::SageMakerMMELoadModel(
9721088
} else if (err != nullptr) {
9731089
SageMakerMMEHandleOOMError(req, err);
9741090
} else {
975-
std::lock_guard<std::mutex> lock(mutex_);
1091+
std::lock_guard<std::mutex> lock(models_list_mutex_);
9761092

9771093
sagemaker_models_list_.emplace(model_name, repo_parent_path);
9781094
evhtp_send_reply(req, EVHTP_RES_OK);
@@ -995,5 +1111,4 @@ SagemakerAPIServer::SageMakerMMELoadModel(
9951111

9961112
return;
9971113
}
998-
9991114
}} // namespace triton::server

src/sagemaker_server.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -26,6 +26,7 @@
2626
#pragma once
2727

2828
#include <sys/stat.h>
29+
2930
#include <fstream>
3031
#include <mutex>
3132

@@ -105,6 +106,9 @@ class SagemakerAPIServer : public HTTPAPIServer {
105106

106107
void SageMakerMMEUnloadModel(evhtp_request_t* req, const char* model_name);
107108

109+
TRITONSERVER_Error* SageMakerMMECheckUnloadedModelIsUnavailable(
110+
const char* model_name, bool* is_model_unavailable);
111+
108112
void SageMakerMMEListModel(evhtp_request_t* req);
109113

110114
void SageMakerMMEGetModel(evhtp_request_t* req, const char* model_name);
@@ -155,7 +159,13 @@ class SagemakerAPIServer : public HTTPAPIServer {
155159
std::unordered_map<std::string, std::string> sagemaker_models_list_;
156160

157161
/* Mutex to handle concurrent updates */
158-
std::mutex mutex_;
162+
std::mutex models_list_mutex_;
163+
164+
/* Constants */
165+
const uint32_t UNLOAD_TIMEOUT_SECS_ = 350;
166+
const uint32_t UNLOAD_SLEEP_MILLISECONDS_ = 500;
167+
const std::string UNLOAD_EXPECTED_STATE_ = "UNAVAILABLE";
168+
const std::string UNLOAD_EXPECTED_REASON_ = "unloaded";
159169
};
160170

161171
}} // namespace triton::server

0 commit comments

Comments
 (0)