src/ensemble_scheduler/ensemble_scheduler.cc (87 additions, 0 deletions)

@@ -28,6 +28,7 @@

#include "ensemble_scheduler.h"

#include <condition_variable>
#include <mutex>

#include "cuda_utils.h"
@@ -45,6 +46,9 @@ class EnsembleContext;

using IterationCount = size_t;

// Timeout for blocking on the backpressure condition variable, used to
// avoid potential deadlocks
constexpr int kMutexTimeoutSeconds = 300;

// Check if the model is configured to preserve the order of responses.
// This is critical for async execution of ResponseComplete callbacks.
inline bool
@@ -370,6 +374,13 @@ class EnsembleContext {

size_t inflight_step_counter_;

// Backpressure support: Limits memory growth from decoupled models.
// Tracks inflight responses per step; blocks producers when downstream
// consumers are overloaded. Only active if max_inflight_responses_ > 0.
std::vector<size_t> step_inflight_response_counts_;
std::vector<std::unique_ptr<std::mutex>> step_mutexes_;
std::vector<std::unique_ptr<std::condition_variable>> step_cv_vec_;
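// (std::mutex and std::condition_variable are neither movable nor copyable,
// hence the unique_ptr indirection so these vectors can be resized.)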

// pointer that either points to 'pruned_tensor_to_step_' or to
// 'info_->tensor_to_step_' if all ensemble outputs are requested
std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_;
@@ -505,6 +516,20 @@ EnsembleContext::EnsembleContext(
}
}

// Initialize backpressure tracking. The counters are sized unconditionally;
// mutexes and condition variables are allocated only when a limit is set.
size_t num_steps = info_->steps_.size();
step_inflight_response_counts_.resize(num_steps, 0);

if (info_->max_inflight_responses_ > 0) {
step_mutexes_.resize(num_steps);
step_cv_vec_.resize(num_steps);

for (size_t i = 0; i < num_steps; i++) {
step_mutexes_[i] = std::make_unique<std::mutex>();
step_cv_vec_[i] = std::make_unique<std::condition_variable>();
}
}

if (ensemble_status_.IsOk()) {
request_id_ = lrequest->Id();
correlation_id_ = lrequest->CorrelationId();
@@ -907,6 +932,16 @@ EnsembleContext::UpdateEnsembleState(
if (completed_step->response_flags_ &
TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
inflight_step_counter_--;

size_t completed_step_idx = completed_step->step_idx_;

// Decrement step_inflight_response_counts_, then notify any producer
// threads blocked waiting for this step's capacity
if (info_->max_inflight_responses_ > 0 && !step_cv_vec_.empty()) {
std::lock_guard<std::mutex> lk(*step_mutexes_[completed_step_idx]);
step_inflight_response_counts_[completed_step_idx]--;
step_cv_vec_[completed_step_idx]->notify_one();
}
}
RETURN_IF_ERROR(ConsumeResponse(completed_step));
updated_tensors->swap(completed_step->updated_tensors_);
@@ -950,6 +985,13 @@ EnsembleContext::GetNextSteps(
for (const auto& idx : next_step_idx) {
steps->emplace_back();
RETURN_IF_ERROR(InitStep(idx.first, idx.second, &(steps->back())));

// Track as inflight. Checked by producers for backpressure; decremented on
// completion.
if (info_->max_inflight_responses_ > 0 && !step_mutexes_.empty()) {
std::lock_guard<std::mutex> lk(*step_mutexes_[idx.first]);
step_inflight_response_counts_[idx.first]++;
}
}
inflight_step_counter_ += steps->size();

@@ -1392,6 +1434,39 @@ EnsembleContext::ScheduleSteps(
{
for (auto& step : steps) {
step->ctx_ = context;
size_t this_step_idx = step->step_idx_;

// Block if this step is overloaded.
if (context->info_->max_inflight_responses_ > 0 &&
!context->step_cv_vec_.empty()) {
std::unique_lock<std::mutex> lk(*context->step_mutexes_[this_step_idx]);

auto timeout = std::chrono::seconds(kMutexTimeoutSeconds);
auto cancelled = [&]() {
auto& req = context->request_tracker_->Request();
return (req == nullptr) || req->IsCancelled();
};

bool capacity_available =
context->step_cv_vec_[this_step_idx]->wait_for(lk, timeout, [&] {
return cancelled() ||
(context->step_inflight_response_counts_[this_step_idx] <
context->info_->max_inflight_responses_);
});
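// Note: wait_for returns the predicate's result, so 'capacity_available'
// is also true when the request was cancelled, not only when capacity
// actually opened up; false means a timeout while still over the limit.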

// Log error only if timeout occurred (not cancellation).
if (!capacity_available && !cancelled()) {
LOG_ERROR << "[Internal Error] Ensemble '"
<< context->info_->ensemble_name_
<< "' unable to schedule step " << this_step_idx
<< " (inflight: "
<< context->step_inflight_response_counts_[this_step_idx]
<< " >= limit: " << context->info_->max_inflight_responses_
<< ") for " << kMutexTimeoutSeconds
<< " seconds. Proceeding to avoid deadlock.";
}
}

bool should_schedule = false;
// Must release lock before InferAsync to avoid deadlock, as the same thread
// will be calling request/response callbacks on cache hits, which will
@@ -1602,6 +1677,18 @@ EnsembleScheduler::EnsembleScheduler(
}
}
callback_pool_ = is_->EnsembleCallbackPool();

// Backpressure configuration from protobuf field. Limits concurrent responses
// from decoupled steps to prevent memory growth. Value of 0 means unlimited.
if (config.has_ensemble_scheduling()) {
info_->max_inflight_responses_ =
config.ensemble_scheduling().max_inflight_responses();
if (info_->max_inflight_responses_ > 0) {
LOG_INFO << "Ensemble model '" << config.name()
<< "' configured with max_inflight_responses: "
<< info_->max_inflight_responses_;
}
}
}

EnsembleScheduler::~EnsembleScheduler()
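For reference, a minimal config.pbtxt sketch showing where the new field sits, matching the `config.ensemble_scheduling().max_inflight_responses()` access above; the model name, limit value, and elided step list are illustrative only, not taken from this PR:

name: "ensemble_model"  # hypothetical model name
platform: "ensemble"
ensemble_scheduling {
  max_inflight_responses: 8  # illustrative limit; 0 (default) disables backpressure
  step [
    # ... ensemble steps as usual ...
  ]
}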
src/ensemble_scheduler/ensemble_scheduler.h (6 additions, 0 deletions)

@@ -83,6 +83,12 @@ struct EnsembleInfo {

// backward path, ensemble tensor to the step that provides its data
std::unordered_map<std::string, size_t> tensor_to_prev_step_;

// Maximum concurrent inflight responses from steps to downstream consumers.
// Prevents memory growth by blocking producers when limit is reached.
// Default value is 0, which indicates unlimited (no backpressure applied).
// Configured via the 'max_inflight_responses' field of 'ensemble_scheduling'
// in config.pbtxt.
size_t max_inflight_responses_ = 0;
};

// Scheduler that implements ensemble scheduling.
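The mechanism reduces to a counting gate with a timed wait. Below is a minimal, self-contained sketch of that pattern; the class and all names are illustrative, not part of this PR, and the real scheduler's predicate additionally checks request cancellation:

#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Illustrative counting gate: producers block while 'count_' is at the
// limit, consumers release capacity as responses complete.
class InflightGate {
 public:
  explicit InflightGate(size_t limit) : limit_(limit) {}

  // Producer side. Blocks until capacity is available or 'timeout' elapses.
  // Returns false on timeout; like the scheduler above, the caller proceeds
  // anyway rather than deadlock.
  bool Acquire(std::chrono::seconds timeout) {
    std::unique_lock<std::mutex> lk(mu_);
    bool ok = cv_.wait_for(lk, timeout, [&] { return count_ < limit_; });
    ++count_;  // proceed even on timeout, mirroring the scheduler's choice
    return ok;
  }

  // Consumer side, called when a final response arrives: free one slot and
  // wake one blocked producer.
  void Release() {
    std::lock_guard<std::mutex> lk(mu_);
    --count_;
    cv_.notify_one();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  size_t count_ = 0;
  const size_t limit_;
};

int main() {
  InflightGate gate(2);  // behaves like max_inflight_responses: 2

  std::thread consumer([&] {
    for (int i = 0; i < 4; ++i) {
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      gate.Release();  // simulate a COMPLETE_FINAL response arriving
    }
  });

  for (int i = 0; i < 4; ++i) {
    bool ok = gate.Acquire(std::chrono::seconds(5));
    std::cout << "scheduled step " << i << (ok ? "" : " (timed out)") << "\n";
  }
  consumer.join();
  return 0;
}

Letting the producer proceed after a timeout, rather than failing the request, mirrors the PR's choice to log and continue: a transient overshoot of the limit is traded for deadlock safety.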