@@ -98,6 +98,7 @@ class ModelState : public BackendModel {
   {
     return enable_nvfuser_pair_;
   }
+  bool EnabledCacheCleaning() { return enable_cache_cleaning_; }
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
 
@@ -114,6 +115,9 @@ class ModelState : public BackendModel {
   // Flag to indicate whether inference mode is enabled. Defaults to false.
   bool enable_inference_mode_;
 
+  // Flag to indicate whether cache cleaning after each run is enabled. Defaults to false.
+  bool enable_cache_cleaning_;
+
   // Flag to indicate whether weight sharing is enabled. Defaults to false.
   bool enable_weight_sharing_;
 
@@ -173,7 +177,8 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
     : BackendModel(triton_model), enable_optimized_execution_(true),
-      enable_inference_mode_(false), enable_weight_sharing_(false),
+      enable_inference_mode_(false), enable_cache_cleaning_(false),
+      enable_weight_sharing_(false),
       enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
@@ -298,6 +303,25 @@ ModelState::ParseParameters()
       " for model instance '" + Name() + "'")
          .c_str());
 
+  // If 'ENABLE_CACHE_CLEANING' is not present in 'parameters' then
+  // no update is made to 'enable_cache_cleaning_'.
+  err = ParseParameter(
+      params, "ENABLE_CACHE_CLEANING", &enable_cache_cleaning_);
+  if (err != nullptr) {
+    if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+      return err;
+    } else {
+      TRITONSERVER_ErrorDelete(err);
+    }
+  }
+
+  LOG_MESSAGE(
+      TRITONSERVER_LOG_INFO,
+      (std::string("Cache Cleaning is ") +
+       (enable_cache_cleaning_ ? "enabled" : "disabled") +
+       " for model instance '" + Name() + "'")
+          .c_str());
+
   // If 'INFERENCE_MODE' is not present in 'parameters' then no update is made
   // to 'enable_inference_mode_'.
   err = ParseParameter(params, "INFERENCE_MODE", &enable_inference_mode_);
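The hunk above follows the backend's existing pattern for optional parameters: ParseParameter reports a missing key as an error, the NOT_FOUND case is swallowed so the compiled-in default survives, and any other error (for example an unparsable value) propagates to the caller. In Triton the key/value pairs come from the `parameters` section of the model's config.pbtxt. Below is a minimal standalone sketch of that pattern; `ErrCode` and `ParseBoolParam` are hypothetical stand-ins for the real TRITONSERVER_Error plumbing, not the backend's actual API.

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for TRITONSERVER_Error codes.
enum class ErrCode { None, NotFound, InvalidArg };

// Parse a boolean-valued parameter from a key/value map. Returns NotFound
// when the key is absent so the caller can decide whether that is fatal.
ErrCode
ParseBoolParam(
    const std::map<std::string, std::string>& params, const std::string& key,
    bool* value)
{
  auto it = params.find(key);
  if (it == params.end()) {
    return ErrCode::NotFound;
  }
  if (it->second == "true") {
    *value = true;
  } else if (it->second == "false") {
    *value = false;
  } else {
    return ErrCode::InvalidArg;  // present but unparsable: a real error
  }
  return ErrCode::None;
}

int
main()
{
  std::map<std::string, std::string> params{{"INFERENCE_MODE", "true"}};

  bool enable_cache_cleaning = false;  // default, as in the diff
  // Missing key: keep the default instead of failing, mirroring the
  // NOT_FOUND handling in ParseParameters() above.
  ErrCode err =
      ParseBoolParam(params, "ENABLE_CACHE_CLEANING", &enable_cache_cleaning);
  if (err != ErrCode::None && err != ErrCode::NotFound) {
    std::cerr << "fatal parse error\n";
    return 1;
  }
  std::cout << "Cache Cleaning is "
            << (enable_cache_cleaning ? "enabled" : "disabled") << "\n";
  return 0;
}
```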
@@ -453,6 +477,9 @@ class ModelInstanceState : public BackendModelInstance {
   void ProcessRequests(
       TRITONBACKEND_Request** requests, const uint32_t request_count);
 
+  // Clear CUDA cache
+  void ClearCache();
+
  private:
   ModelInstanceState(
       ModelState* model_state,
@@ -585,16 +612,21 @@ ModelInstanceState::ModelInstanceState(
   THROW_IF_BACKEND_INSTANCE_ERROR(ValidateOutputs());
 }
 
-ModelInstanceState::~ModelInstanceState()
+void ModelInstanceState::ClearCache()
 {
-  torch_model_.reset();
 #ifdef TRITON_ENABLE_GPU
   if (device_.is_cuda()) {
     c10::cuda::CUDACachingAllocator::emptyCache();
   }
 #endif  // TRITON_ENABLE_GPU
 }
 
+ModelInstanceState::~ModelInstanceState()
+{
+  torch_model_.reset();
+  ClearCache();
+}
+
 TRITONSERVER_Error*
 ModelInstanceState::ValidateBooleanSequenceControl(
     triton::common::TritonJson::Value& sequence_batching,
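This refactor is behavior-preserving: the emptyCache() call moves from the destructor into ClearCache(), and the destructor now releases the module and then calls ClearCache(), so the same code can also run after every execution when the flag is set. c10::cuda::CUDACachingAllocator::emptyCache() returns cached-but-unused device blocks to the driver; it does not free memory that live tensors still hold. A minimal standalone libtorch sketch of the call, assuming a CUDA-enabled libtorch build (this is illustration, not Triton code):

```cpp
#include <iostream>

#include <c10/cuda/CUDACachingAllocator.h>
#include <torch/torch.h>

int
main()
{
  if (!torch::cuda::is_available()) {
    std::cout << "CUDA not available; nothing to clean\n";
    return 0;
  }

  {
    // Allocations go through PyTorch's caching allocator; when the tensor
    // is destroyed its block stays cached for reuse rather than being
    // returned to the driver.
    auto t = torch::ones({1024, 1024}, torch::kCUDA);
  }

  // Hand the cached-but-unused blocks back to CUDA, shrinking the
  // process's reserved device memory. This is what ClearCache() does
  // when the instance's device is a CUDA device.
  c10::cuda::CUDACachingAllocator::emptyCache();
  return 0;
}
```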
@@ -2081,6 +2113,10 @@ TRITONBACKEND_ModelInstanceExecute(
   // specific request.
   instance_state->ProcessRequests(requests, request_count);
 
+  if (model_state->EnabledCacheCleaning()) {
+    instance_state->ClearCache();
+  }
+
   return nullptr;  // success
 }
 
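Gating ClearCache() here makes the tradeoff explicit: with ENABLE_CACHE_CLEANING on, reserved device memory is handed back after every batch, which helps when the GPU is shared with other frameworks or processes, but the next execution pays cudaMalloc again instead of reusing a cached block. A rough, non-Triton illustration of that cost, assuming a CUDA-enabled libtorch build:

```cpp
#include <chrono>
#include <iostream>

#include <c10/cuda/CUDACachingAllocator.h>
#include <torch/torch.h>

// Time one ~64 MB allocation. The tensor is destroyed on return, so its
// block goes back into the allocator's cache (not to the driver).
static double
TimedAllocMs()
{
  auto start = std::chrono::steady_clock::now();
  auto t = torch::empty({4096, 4096}, torch::kCUDA);
  torch::cuda::synchronize();
  return std::chrono::duration<double, std::milli>(
             std::chrono::steady_clock::now() - start)
      .count();
}

int
main()
{
  if (!torch::cuda::is_available()) {
    return 0;
  }
  torch::cuda::synchronize();  // pay one-time CUDA context init up front

  std::cout << "cold alloc:    " << TimedAllocMs() << " ms\n";  // cudaMalloc
  std::cout << "cached alloc:  " << TimedAllocMs() << " ms\n";  // reused block
  c10::cuda::CUDACachingAllocator::emptyCache();
  std::cout << "after cleanup: " << TimedAllocMs() << " ms\n";  // cudaMalloc again
  return 0;
}
```

On most setups the third allocation should land closer to the first than the second, which is why the flag defaults to false and cleaning otherwise happens only at instance teardown.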