@@ -628,69 +628,185 @@ SagemakerAPIServer::SageMakerMMEHandleInfer(
628628 }
629629}
630630
631+ TRITONSERVER_Error*
632+ SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable (
633+ const char * model_name, bool * is_model_unavailable)
634+ {
635+ /* Use the RepositoryIndex API to check if the model state has become
636+ UNAVAILABLE i.e. model is no longer in the 'in-the-process-of' being
637+ UNLOADED. Consequently, the reason field should be 'unloaded'.*/
638+ TRITONSERVER_Message* server_model_index_message = nullptr ;
639+ uint32_t ready_flag = 0 ; // value of 1 should be set if only the 'ready'
640+ // models are required from the index. In this case,
641+ // we need all models.
642+ TRITONSERVER_ServerModelIndex (
643+ server_.get (), ready_flag, &server_model_index_message);
644+
645+ std::shared_ptr<TRITONSERVER_Message> shared_ptr_msg (
646+ server_model_index_message,
647+ [](TRITONSERVER_Message* msg) { TRITONSERVER_MessageDelete (msg); });
648+
649+ const char * index_buffer;
650+ size_t index_byte_size;
651+
652+ RETURN_IF_ERR (TRITONSERVER_MessageSerializeToJson (
653+ server_model_index_message, &index_buffer, &index_byte_size));
654+
655+ /* Read into json buffer*/
656+ triton::common::TritonJson::Value server_model_index_json;
657+ server_model_index_json.Parse (index_buffer, index_byte_size);
658+
659+ const char * name;
660+ const char * state;
661+ const char * reason;
662+ const char * version;
663+
664+ size_t name_len;
665+ size_t state_len;
666+ size_t reason_len;
667+ size_t version_len;
668+
669+ for (size_t id = 0 ; id < server_model_index_json.ArraySize (); ++id) {
670+ triton::common::TritonJson::Value index_json;
671+ server_model_index_json.IndexAsObject (id, &index_json);
672+
673+ RETURN_IF_ERR (index_json.MemberAsString (" name" , &name, &name_len));
674+
675+ if (std::string (name) == std::string (model_name)) {
676+ RETURN_IF_ERR (index_json.MemberAsString (" state" , &state, &state_len));
677+
678+ if (std::string (state) == UNLOAD_EXPECTED_STATE_) {
679+ RETURN_IF_ERR (
680+ index_json.MemberAsString (" reason" , &reason, &reason_len));
681+
682+ if (std::string (reason) == UNLOAD_EXPECTED_REASON_) {
683+ *is_model_unavailable = true ;
684+
685+ RETURN_IF_ERR (
686+ index_json.MemberAsString (" version" , &version, &version_len));
687+
688+ LOG_VERBOSE (1 ) << " Discovered model: " << name
689+ << " , version: " << version << " in state: " << state
690+ << " for the reason: " << reason;
691+
692+ break ;
693+ }
694+ }
695+ }
696+ }
697+
698+ return nullptr ;
699+ }
700+
631701void
632702SagemakerAPIServer::SageMakerMMEUnloadModel (
633703 evhtp_request_t * req, const char * model_name)
634704{
635- std::lock_guard<std::mutex> lock (mutex_);
636-
637705 if (sagemaker_models_list_.find (model_name) == sagemaker_models_list_.end ()) {
638706 LOG_VERBOSE (1 ) << " Model " << model_name << " is not loaded." << std::endl;
639707 evhtp_send_reply (req, EVHTP_RES_NOTFOUND); /* 404*/
640708 return ;
641709 }
642710
711+ /* Extract targetModel to log the associated archive */
712+ const char * targetModel =
713+ evhtp_kv_find (req->headers_in , " X-Amzn-SageMaker-Target-Model" );
714+
715+ LOG_INFO << " Unloading SageMaker TargetModel: " << targetModel << std::endl;
716+
717+ auto start_time = std::chrono::high_resolution_clock::now ();
718+
643719 /* Always unload dependents as well - this is required to unload dependents in
644720 * ensemble */
645- triton::common::TritonJson::Value request_parameters (
646- triton::common::TritonJson::ValueType::OBJECT);
647- triton::common::TritonJson::Value unload_parameter (
648- request_parameters, triton::common::TritonJson::ValueType::OBJECT);
649-
650- unload_parameter.AddBool (" unload_dependents" , true );
651- request_parameters.Add (" parameters" , std::move (unload_parameter));
721+ TRITONSERVER_Error* unload_err = nullptr ;
722+ unload_err =
723+ TRITONSERVER_ServerUnloadModelAndDependents (server_.get (), model_name);
652724
653- const char * buffer;
654- size_t byte_size;
725+ if (unload_err != nullptr ) {
726+ EVBufferAddErrorJson (req->buffer_out , unload_err);
727+ evhtp_send_reply (req, EVHTP_RES_BADREQ);
655728
656- triton::common::TritonJson::WriteBuffer json_buffer_;
657- json_buffer_. Clear ();
658- request_parameters. Write (&json_buffer_) ;
729+ LOG_ERROR
730+ << " Error when unloading SageMaker Model with dependents for model: "
731+ << model_name << std::endl ;
659732
660- byte_size = json_buffer_.Size ();
661- buffer = json_buffer_.Base ();
733+ TRITONSERVER_ErrorDelete (unload_err);
734+ return ;
735+ }
662736
663- evbuffer_add (req->buffer_in , buffer, byte_size);
737+ /* Note: Model status check is repo-specific and therefore must be run before
738+ * unregistering the repo, else the model information is lost*/
739+ bool is_model_unavailable = false ;
740+ int64_t unload_time_in_secs = 0 ;
741+
742+ /* Wait for the model to be completely unloaded. SageMaker waits a maximum
743+ of 360 seconds for the UNLOAD request to timeout. Setting a limit of 350
744+ seconds for Triton unload. This should be run only if above UNLOAD call has
745+ succeeded.*/
746+ if (unload_err == nullptr ) {
747+ LOG_VERBOSE (1 ) << " Using Model Repository Index during UNLOAD to check for "
748+ " status of model: "
749+ << model_name;
750+ while (is_model_unavailable == false &&
751+ unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
752+ LOG_VERBOSE (1 ) << " In the loop to wait for model to be unavailable" ;
753+ unload_err = SageMakerMMECheckUnloadedModelIsUnavailable (
754+ model_name, &is_model_unavailable);
755+ if (unload_err != nullptr ) {
756+ LOG_ERROR << " Error: Received non-zero exit code on checking for "
757+ " model unavailability. "
758+ << TRITONSERVER_ErrorMessage (unload_err);
759+ break ;
760+ }
761+ std::this_thread::sleep_for (
762+ std::chrono::milliseconds (UNLOAD_SLEEP_MILLISECONDS_));
664763
665- /* Extract targetModel to log the associated archive */
666- const char * targetModel =
667- evhtp_kv_find (req->headers_in , " X-Amzn-SageMaker-Target-Model" );
764+ auto end_time = std::chrono::high_resolution_clock::now ();
668765
669- LOG_INFO << " Unloading SageMaker TargetModel: " << targetModel << std::endl;
766+ unload_time_in_secs = std::chrono::duration_cast<std::chrono::seconds>(
767+ end_time - start_time)
768+ .count ();
769+ }
770+ LOG_INFO << " UNLOAD for model " << model_name << " completed in "
771+ << unload_time_in_secs << " seconds." ;
772+ TRITONSERVER_ErrorDelete (unload_err);
773+ }
670774
671- HandleRepositoryControl (req, " " , model_name, " unload" );
775+ if ((is_model_unavailable == false ) &&
776+ (unload_time_in_secs >= UNLOAD_TIMEOUT_SECS_)) {
777+ LOG_ERROR << " Error: UNLOAD did not complete within expected "
778+ << UNLOAD_TIMEOUT_SECS_
779+ << " seconds. This may "
780+ " result in SageMaker UNLOAD timeout." ;
781+ }
672782
673783 std::string repo_parent_path = sagemaker_models_list_.at (model_name);
674784
675- TRITONSERVER_Error* unload_err = TRITONSERVER_ServerUnregisterModelRepository (
785+ TRITONSERVER_Error* unregister_err = nullptr ;
786+
787+ unregister_err = TRITONSERVER_ServerUnregisterModelRepository (
676788 server_.get (), repo_parent_path.c_str ());
677789
678- if (unload_err != nullptr ) {
790+ if (unregister_err != nullptr ) {
679791 EVBufferAddErrorJson (req->buffer_out , unload_err);
680792 evhtp_send_reply (req, EVHTP_RES_BADREQ);
681793 LOG_ERROR << " Unable to unregister model repository for path: "
682794 << repo_parent_path << std::endl;
683- TRITONSERVER_ErrorDelete (unload_err);
795+ } else {
796+ evhtp_send_reply (req, EVHTP_RES_OK);
684797 }
685798
799+ TRITONSERVER_ErrorDelete (unregister_err);
800+
801+ std::lock_guard<std::mutex> lock (models_list_mutex_);
686802 sagemaker_models_list_.erase (model_name);
687803}
688804
689805void
690806SagemakerAPIServer::SageMakerMMEGetModel (
691807 evhtp_request_t * req, const char * model_name)
692808{
693- std::lock_guard<std::mutex> lock (mutex_ );
809+ std::lock_guard<std::mutex> lock (models_list_mutex_ );
694810
695811 if (sagemaker_models_list_.find (model_name) == sagemaker_models_list_.end ()) {
696812 evhtp_send_reply (req, EVHTP_RES_NOTFOUND); /* 404*/
@@ -721,7 +837,7 @@ SagemakerAPIServer::SageMakerMMEGetModel(
721837void
722838SagemakerAPIServer::SageMakerMMEListModel (evhtp_request_t * req)
723839{
724- std::lock_guard<std::mutex> lock (mutex_ );
840+ std::lock_guard<std::mutex> lock (models_list_mutex_ );
725841
726842 triton::common::TritonJson::Value sagemaker_list_json (
727843 triton::common::TritonJson::ValueType::OBJECT);
@@ -866,8 +982,8 @@ SagemakerAPIServer::SageMakerMMELoadModel(
866982 if (config_fstream.is_open ()) {
867983 ensemble_config_content << config_fstream.rdbuf ();
868984 } else {
869- continue ; // A valid config.pbtxt does not exit at this path, or cannot
870- // be read
985+ continue ; // A valid config.pbtxt does not exist at this path, or
986+ // cannot be read
871987 }
872988
873989 /* Compare matched string with `platform: "ensemble"` or
@@ -972,7 +1088,7 @@ SagemakerAPIServer::SageMakerMMELoadModel(
9721088 } else if (err != nullptr ) {
9731089 SageMakerMMEHandleOOMError (req, err);
9741090 } else {
975- std::lock_guard<std::mutex> lock (mutex_ );
1091+ std::lock_guard<std::mutex> lock (models_list_mutex_ );
9761092
9771093 sagemaker_models_list_.emplace (model_name, repo_parent_path);
9781094 evhtp_send_reply (req, EVHTP_RES_OK);
@@ -995,5 +1111,4 @@ SagemakerAPIServer::SageMakerMMELoadModel(
9951111
9961112 return ;
9971113}
998-
9991114}} // namespace triton::server
0 commit comments