Skip to content

Commit 9a9815e

Browse files
committed
Update
1 parent a4ed606 commit 9a9815e

File tree

2 files changed

+107
-0
lines changed

2 files changed

+107
-0
lines changed

src/ensemble_scheduler/ensemble_scheduler.cc

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include "ensemble_scheduler.h"
3030

31+
#include <condition_variable>
3132
#include <mutex>
3233

3334
#include "cuda_utils.h"
@@ -370,6 +371,13 @@ class EnsembleContext {
370371

371372
size_t inflight_step_counter_;
372373

374+
// Backpressure support: limits memory growth from decoupled models by
375+
// blocking producers while a downstream step is at its inflight limit. Counts
376+
// are always tracked; mutexes/CVs exist only if max_inflight_responses_ > 0.
377+
std::vector<size_t> step_inflight_response_counts_;
378+
std::vector<std::unique_ptr<std::mutex>> step_mutexes_;
379+
std::vector<std::unique_ptr<std::condition_variable>> step_cvs_;
380+
373381
// pointer that either points to 'pruned_tensor_to_step_' or to
374382
// 'info_->tensor_to_step_' if all ensemble outputs are requested
375383
std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_;
@@ -505,6 +513,20 @@ EnsembleContext::EnsembleContext(
505513
}
506514
}
507515

516+
// Initialize backpressure tracking; mutexes/CVs are created only if enabled.
517+
size_t num_steps = info_->steps_.size();
518+
step_inflight_response_counts_.resize(num_steps, 0);
519+
520+
if (info_->max_inflight_responses_ > 0) {
521+
step_mutexes_.resize(num_steps);
522+
step_cvs_.resize(num_steps);
523+
524+
for (size_t i = 0; i < num_steps; i++) {
525+
step_mutexes_[i].reset(new std::mutex());
526+
step_cvs_[i].reset(new std::condition_variable());
527+
}
528+
}
529+
508530
if (ensemble_status_.IsOk()) {
509531
request_id_ = lrequest->Id();
510532
correlation_id_ = lrequest->CorrelationId();
@@ -669,6 +691,46 @@ EnsembleContext::ResponseComplete(
669691
auto pool = step_raw_ptr->ctx_->CallbackPool();
670692
auto fn = [response, flags, step_raw_ptr]() {
671693
auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
694+
auto& context = step_ptr->ctx_;
695+
size_t this_step_idx = step_ptr->step_idx_;
696+
const auto& istep = context->info_->steps_[this_step_idx];
697+
698+
// Block this producer if downstream consumers are overloaded.
699+
// Prevents memory exhaustion by limiting concurrent inflight responses.
700+
if (context->info_->max_inflight_responses_ > 0 &&
701+
!context->step_cvs_.empty()) {
702+
for (const auto& output_pair : istep.output_to_tensor_) {
703+
const auto& tensor_name = output_pair.second;
704+
const auto& downstream_steps = (*context->tensor_to_step_)[tensor_name];
705+
706+
for (const auto& downstream_step_idx : downstream_steps) {
707+
std::unique_lock<std::mutex> lk(
708+
*context->step_mutexes_[downstream_step_idx]);
709+
710+
// Block if downstream inflight count >= limit. Timeout after 300s to
711+
// prevent any deadlock. Unblocks when downstream completes a request.
712+
auto timeout = std::chrono::seconds(300);
713+
bool capacity_available =
714+
context->step_cvs_[downstream_step_idx]->wait_for(
715+
lk, timeout, [&] {
716+
return context->step_inflight_response_counts_
717+
[downstream_step_idx] <
718+
context->info_->max_inflight_responses_;
719+
});
720+
721+
if (!capacity_available) {
722+
LOG_ERROR
723+
<< "[Internal Error] Ensemble '"
724+
<< context->info_->ensemble_name_ << "' step " << this_step_idx
725+
<< " blocked waiting for downstream step "
726+
<< downstream_step_idx << " (inflight: "
727+
<< context->step_inflight_response_counts_[downstream_step_idx]
728+
<< " >= limit: " << context->info_->max_inflight_responses_
729+
<< ") for 300 seconds. Proceeding to avoid deadlock.";
730+
}
731+
}
732+
}
733+
}
672734
step_ptr->response_flags_ = flags;
673735
step_ptr->response_ = response;
674736

@@ -907,6 +969,15 @@ EnsembleContext::UpdateEnsembleState(
907969
if (completed_step->response_flags_ &
908970
TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
909971
inflight_step_counter_--;
972+
973+
size_t completed_step_idx = completed_step->step_idx_;
974+
step_inflight_response_counts_[completed_step_idx]--;
975+
976+
// Wake one producer thread blocked waiting for capacity on this step.
977+
if (info_->max_inflight_responses_ > 0 && !step_cvs_.empty()) {
978+
std::lock_guard<std::mutex> lk(*step_mutexes_[completed_step_idx]);
979+
step_cvs_[completed_step_idx]->notify_one();
980+
}
910981
}
911982
RETURN_IF_ERROR(ConsumeResponse(completed_step));
912983
updated_tensors->swap(completed_step->updated_tensors_);
@@ -950,6 +1021,10 @@ EnsembleContext::GetNextSteps(
9501021
for (const auto& idx : next_step_idx) {
9511022
steps->emplace_back();
9521023
RETURN_IF_ERROR(InitStep(idx.first, idx.second, &(steps->back())));
1024+
1025+
// Track as inflight. Checked by producers for backpressure; decremented when
1026+
// the step's final response (COMPLETE_FINAL flag) is processed.
1027+
step_inflight_response_counts_[idx.first]++;
9531028
}
9541029
inflight_step_counter_ += steps->size();
9551030

@@ -1602,6 +1677,32 @@ EnsembleScheduler::EnsembleScheduler(
16021677
}
16031678
}
16041679
callback_pool_ = is_->EnsembleCallbackPool();
1680+
1681+
// Parse backpressure configuration. Limits concurrent responses from
1682+
// decoupled steps to prevent memory growth.
1683+
if (config.parameters().contains("max_ensemble_inflight_responses")) {
1684+
const auto& param =
1685+
config.parameters().at("max_ensemble_inflight_responses");
1686+
const std::string& value = param.string_value();
1687+
try {
1688+
const int64_t size = std::stoll(value);
1689+
if (size > 0) {
1690+
info_->max_inflight_responses_ = static_cast<size_t>(size);
1691+
LOG_INFO << "Ensemble model '" << config.name()
1692+
<< "' configured with max_ensemble_inflight_responses: "
1693+
<< info_->max_inflight_responses_;
1694+
} else {
1695+
LOG_ERROR
1696+
<< "Ignoring 'max_ensemble_inflight_responses' for ensemble model '"
1697+
<< config.name() << "': value must be positive, got " << size;
1698+
}
1699+
}
1700+
catch (const std::invalid_argument& ia) {
1701+
LOG_ERROR
1702+
<< "Failed to parse 'max_ensemble_inflight_responses' for ensemble '"
1703+
<< config.name() << "': " << ia.what();
1704+
}
1705+
}
16051706
}
16061707

16071708
EnsembleScheduler::~EnsembleScheduler()

src/ensemble_scheduler/ensemble_scheduler.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ struct EnsembleInfo {
8383

8484
// backward path, ensemble tensor to the step that provides its data
8585
std::unordered_map<std::string, size_t> tensor_to_prev_step_;
86+
87+
// Maximum concurrent inflight responses from steps to downstream
88+
// consumers. Prevents memory growth by blocking producers at the limit.
89+
// Value of 0 means unlimited (default). Configured via parameter
90+
// 'max_ensemble_inflight_responses' in ensemble config.pbtxt.
91+
size_t max_inflight_responses_ = 0;
8692
};
8793

8894
// Scheduler that implements ensemble scheduling.

0 commit comments

Comments
 (0)