@@ -61,6 +61,10 @@ class RequestTracker {
 
   std::unique_ptr<InferenceRequest>& Request() { return request_; }
 
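+  // Accessors used when recording response cache statistics for the ensemble.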
+  InferenceStatsAggregator* StatsAggregator() { return stats_aggregator_; }
+
+  MetricModelReporter* MetricReporter() { return metric_reporter_; }
+
   InferenceStatsAggregator& ContextStatsAggregator()
   {
     return context_stats_aggregator_;
@@ -316,6 +320,9 @@ class EnsembleContext {
       const std::set<std::pair<std::string, IterationCount>>& updated_tensors,
       std::unique_ptr<InferenceResponse>* response);
 
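+  // Cache the top-level ensemble response for this request (cache-miss path).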
+  void CacheEnsembleTopLevelRequest(
+      std::unique_ptr<InferenceResponse>& response);
+
   InferenceServer* is_;
 
   EnsembleInfo* info_;
@@ -1033,6 +1040,50 @@ EnsembleContext::ReshapeTensorDims(
   return res;
 }
 
+// Insert the top-level ensemble response into the response cache.
+void
+EnsembleContext::CacheEnsembleTopLevelRequest(
+    std::unique_ptr<InferenceResponse>& response)
+{
+  const std::string key = request_tracker_->Request()->CacheKey();
+  const bool is_key_set = request_tracker_->Request()->CacheKeyIsSet();
+
+#ifdef TRITON_ENABLE_STATS
+  const uint64_t lookup_end_ns =
+      request_tracker_->Request()->CacheLookupEndNs();
+  const uint64_t lookup_start_ns =
+      request_tracker_->Request()->CacheLookupStartNs();
+#endif
+
+  if (!is_key_set) {
+    LOG_ERROR << "Request cache key was not set correctly.";
+  }
+
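+  // Insert the response into the cache under the request's cache key.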
+  auto cache = is_->CacheManager()->Cache();
+#ifdef TRITON_ENABLE_STATS
+  const uint64_t insert_start_ns = CaptureTimeNs();
+#endif
+  auto status = cache->Insert(response.get(), key);
+  if (!status.IsOk()) {
+    LOG_ERROR << "Failed to insert key [" << key
+              << "] into response cache: " << status.Message();
+  }
+
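+  // Report the cache-miss time as the combined lookup and insert durations.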
+#ifdef TRITON_ENABLE_STATS
+  const uint64_t insert_end_ns = CaptureTimeNs();
+  uint64_t lookup_ns = lookup_end_ns - lookup_start_ns;
+  if (lookup_start_ns > lookup_end_ns) {
+    lookup_ns = 0;
+    LOG_ERROR << "Request lookup duration was not set correctly.";
+  }
+  uint64_t insert_ns = insert_end_ns - insert_start_ns;
+  uint64_t cache_miss_ns = lookup_ns + insert_ns;
+  request_tracker_->StatsAggregator()->UpdateSuccessCacheMiss(
+      request_tracker_->MetricReporter(), cache_miss_ns);
+#endif
+}
+
+
 Status
 EnsembleContext::FinishEnsemble(std::unique_ptr<InferenceResponse>&& response)
 {
@@ -1053,6 +1104,10 @@ EnsembleContext::FinishEnsemble(std::unique_ptr<InferenceResponse>&& response)
           ? TRITONSERVER_RESPONSE_COMPLETE_FINAL
           : 0;
   if (response != nullptr) {
+    // Cache the response if response caching is enabled for this ensemble.
+    if (info_->is_cache_enabled_) {
+      CacheEnsembleTopLevelRequest(response);
+    }
     InferenceResponse::Send(std::move(response), flags);
     response_sent_ = true;
   } else if (flags != 0) {
@@ -1319,6 +1374,21 @@ EnsembleScheduler::Create(
   return Status::Success;
 }
 
+
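+// Look up the request in the response cache; on a hit, cached_response is
+// populated and a cache-hit statistic is reported for the request.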
+void
+EnsembleScheduler::CacheLookUp(
+    std::unique_ptr<InferenceRequest>& request,
+    std::unique_ptr<InferenceResponse>& cached_response)
+{
+  auto cache = is_->CacheManager()->Cache();
+  bool is_lookup_success = CacheLookUpUtil(request, cached_response, cache);
+  if (is_lookup_success) {
+#ifdef TRITON_ENABLE_STATS
+    request->ReportStatisticsCacheHit(metric_reporter_.get());
+#endif
+  }
+}
+
 Status
 EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
 {
@@ -1333,6 +1403,19 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
       TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT, "EnsembleScheduler Enqueue");
 #endif // TRITON_ENABLE_TRACING
 
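+  // Try to serve the request from the response cache before scheduling the
+  // ensemble; on a hit, send the cached response and release the request.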
+  std::unique_ptr<InferenceResponse> cached_response;
+  if (info_->is_cache_enabled_) {
+    CacheLookUp(request, cached_response);
+  }
+
+  if (cached_response != nullptr) {
+    InferenceResponse::Send(
+        std::move(cached_response), TRITONSERVER_RESPONSE_COMPLETE_FINAL);
+    InferenceRequest::Release(
+        std::move(request), TRITONSERVER_REQUEST_RELEASE_ALL);
+    return Status::Success;
+  }
+
   // Add additional callback to keep track of in-flight count
   ++inflight_count_;
   request->AddInternalReleaseCallback(
@@ -1387,6 +1470,10 @@ EnsembleScheduler::EnsembleScheduler(
   // This config field is filled internally for ensemble models
   info_->is_decoupled_ = config.model_transaction_policy().decoupled();
 
+  // Response caching is used only if enabled in the model config and on the server.
+  info_->is_cache_enabled_ =
+      config.response_cache().enable() && is_->ResponseCacheEnabled();
+
   for (const auto& input : config.input()) {
     info_->tensor_to_step_.emplace(input.name(), std::set<size_t>());
     if (input.optional()) {