121 changes: 121 additions & 0 deletions src/ensemble_scheduler/ensemble_scheduler.cc
@@ -28,6 +28,7 @@

#include "ensemble_scheduler.h"

#include <condition_variable>
#include <mutex>

#include "cuda_utils.h"
@@ -45,6 +46,9 @@ class EnsembleContext;

using IterationCount = size_t;

// Timeout (in seconds) for backpressure waits on step capacity, used to
// avoid potential deadlocks
constexpr int kMutexTimeoutSeconds = 300;

// Check if the model is configured to preserve the order of responses.
// This is critical for async execution of ResponseComplete callbacks.
inline bool
@@ -370,6 +374,13 @@ class EnsembleContext {

size_t inflight_step_counter_;

// Backpressure support: Limits memory growth from decoupled models.
// Tracks inflight responses per step; blocks producers when downstream
// consumers are overloaded. Only active if max_inflight_responses_ > 0.
std::vector<size_t> step_inflight_response_counts_;
std::vector<std::unique_ptr<std::mutex>> step_mutexes_;
std::vector<std::unique_ptr<std::condition_variable>> step_cvs_;

// pointer that either points to 'pruned_tensor_to_step_' or to
// 'info_->tensor_to_step_' if all ensemble outputs are requested
std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_;
@@ -505,6 +516,20 @@ EnsembleContext::EnsembleContext(
}
}

// Initialize backpressure tracking. Inflight counts are always sized; the
// per-step mutexes and condition variables are allocated only when a limit
// is configured.
size_t num_steps = info_->steps_.size();
step_inflight_response_counts_.resize(num_steps, 0);

if (info_->max_inflight_responses_ > 0) {
step_mutexes_.resize(num_steps);
step_cvs_.resize(num_steps);

for (size_t i = 0; i < num_steps; i++) {
step_mutexes_[i] = std::make_unique<std::mutex>();
step_cvs_[i] = std::make_unique<std::condition_variable>();
}
}

if (ensemble_status_.IsOk()) {
request_id_ = lrequest->Id();
correlation_id_ = lrequest->CorrelationId();
@@ -669,6 +694,55 @@ EnsembleContext::ResponseComplete(
auto pool = step_raw_ptr->ctx_->CallbackPool();
auto fn = [response, flags, step_raw_ptr]() {
auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
auto& context = step_ptr->ctx_;
size_t this_step_idx = step_ptr->step_idx_;
const auto& istep = context->info_->steps_[this_step_idx];

// Block this producer if downstream consumers are overloaded.
// Prevents memory exhaustion by limiting concurrent inflight responses.
if (context->info_->max_inflight_responses_ > 0 &&
!context->step_cvs_.empty()) {
for (const auto& output_pair : istep.output_to_tensor_) {
const auto& tensor_name = output_pair.second;
const auto& downstream_steps = (*context->tensor_to_step_)[tensor_name];

for (const auto& downstream_step_idx : downstream_steps) {
std::unique_lock<std::mutex> lk(
*context->step_mutexes_[downstream_step_idx]);

// Block if downstream inflight count >= limit. Timeout to prevent
// potential deadlock. Unblocks when downstream completes a request
// or request is cancelled.
auto timeout = std::chrono::seconds(kMutexTimeoutSeconds);
auto cancelled = [&]() {
auto& req = context->request_tracker_->Request();
return (req == nullptr) || req->IsCancelled();
};

bool capacity_available =
context->step_cvs_[downstream_step_idx]->wait_for(
lk, timeout, [&] {
return cancelled() ||
(context->step_inflight_response_counts_
[downstream_step_idx] <
context->info_->max_inflight_responses_);
});

// Log error only if timeout occurred (not cancellation).
if (!capacity_available && !cancelled()) {
LOG_ERROR
<< "[Internal Error] Ensemble '"
<< context->info_->ensemble_name_ << "' step " << this_step_idx
<< " blocked waiting for downstream step "
<< downstream_step_idx << " (inflight: "
<< context->step_inflight_response_counts_[downstream_step_idx]
<< " >= limit: " << context->info_->max_inflight_responses_
<< ") for " << kMutexTimeoutSeconds
<< " seconds. Proceeding to avoid deadlock.";
}
}
}
}
step_ptr->response_flags_ = flags;
step_ptr->response_ = response;

@@ -907,6 +981,16 @@ EnsembleContext::UpdateEnsembleState(
if (completed_step->response_flags_ &
TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
inflight_step_counter_--;

size_t completed_step_idx = completed_step->step_idx_;

// Decrement step_inflight_response_counts_, then notify any producer
// threads blocked waiting on this step's capacity.
if (info_->max_inflight_responses_ > 0 && !step_cvs_.empty()) {
std::lock_guard<std::mutex> lk(*step_mutexes_[completed_step_idx]);
step_inflight_response_counts_[completed_step_idx]--;
step_cvs_[completed_step_idx]->notify_one();
}
}
RETURN_IF_ERROR(ConsumeResponse(completed_step));
updated_tensors->swap(completed_step->updated_tensors_);
@@ -950,6 +1034,13 @@ EnsembleContext::GetNextSteps(
for (const auto& idx : next_step_idx) {
steps->emplace_back();
RETURN_IF_ERROR(InitStep(idx.first, idx.second, &(steps->back())));

// Track as inflight. Checked by producers for backpressure; decremented on
// completion.
if (info_->max_inflight_responses_ > 0 && !step_mutexes_.empty()) {
std::lock_guard<std::mutex> lk(*step_mutexes_[idx.first]);
step_inflight_response_counts_[idx.first]++;
}
}
inflight_step_counter_ += steps->size();

@@ -1602,6 +1693,36 @@ EnsembleScheduler::EnsembleScheduler(
}
}
callback_pool_ = is_->EnsembleCallbackPool();

// Parse backpressure configuration. Limits concurrent responses from
// decoupled steps to prevent memory growth.
if (config.parameters().contains("max_ensemble_inflight_responses")) {
const auto& param =
config.parameters().at("max_ensemble_inflight_responses");
const std::string& value = param.string_value();
try {
const int64_t size = std::stoll(value);
if (size > 0) {
info_->max_inflight_responses_ = static_cast<size_t>(size);
LOG_INFO << "Ensemble model '" << config.name()
<< "' configured with max_ensemble_inflight_responses: "
<< info_->max_inflight_responses_;
} else {
LOG_ERROR
<< "Ensemble model '" << config.name()
<< "': max_ensemble_inflight_responses must be greater than 0. "
<< "Received '" << size << "'. Falling back to default value ("
<< info_->max_inflight_responses_ << ").";
}
}
catch (const std::exception& e) {
LOG_ERROR << "Ensemble model '" << config.name()
<< "': failed to parse max_ensemble_inflight_responses='"
<< value << "': " << e.what()
<< ". Falling back to default value ("
<< info_->max_inflight_responses_ << ").";
}
}
}

EnsembleScheduler::~EnsembleScheduler()
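The backpressure logic above is interleaved with the ensemble bookkeeping, which makes the core pattern hard to follow in the diff. Below is a minimal, self-contained sketch of that pattern under simplified assumptions: a single step, a bare counter in place of `step_inflight_response_counts_`, an `std::atomic<bool>` standing in for request cancellation, and the inflight increment folded into the producer rather than into `GetNextSteps`. All names and values here are illustrative and do not exist in the Triton code base.

```cpp
// Illustrative sketch of the bounded-inflight pattern only; not Triton code.
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

constexpr size_t kMaxInflight = 4;    // analogous to max_inflight_responses_
constexpr auto kWaitTimeout = std::chrono::seconds(300);

std::mutex mu;
std::condition_variable cv;
size_t inflight = 0;                 // analogous to step_inflight_response_counts_[i]
std::atomic<bool> cancelled{false};  // stands in for request cancellation

// Producer side: mirrors the wait in ResponseComplete. Returns false if the
// wait timed out and we proceeded anyway to avoid a deadlock.
bool
AcquireSlot()
{
  std::unique_lock<std::mutex> lk(mu);
  bool ok = cv.wait_for(lk, kWaitTimeout, [] {
    return cancelled.load() || (inflight < kMaxInflight);
  });
  if (!ok) {
    std::cerr << "timed out waiting for downstream capacity; proceeding"
              << std::endl;
  }
  ++inflight;  // the PR increments in GetNextSteps instead
  return ok;
}

// Consumer side: mirrors the decrement + notify in UpdateEnsembleState.
void
ReleaseSlot()
{
  {
    std::lock_guard<std::mutex> lk(mu);
    if (inflight > 0) {
      --inflight;
    }
  }
  cv.notify_one();
}

int
main()
{
  std::thread producer([] {
    for (int i = 0; i < 16; ++i) {
      AcquireSlot();  // blocks once kMaxInflight slots are in use
    }
  });
  std::thread consumer([] {
    for (int i = 0; i < 16; ++i) {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      ReleaseSlot();  // frees a slot and wakes the blocked producer
    }
  });
  producer.join();
  consumer.join();
  return 0;
}
```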
6 changes: 6 additions & 0 deletions src/ensemble_scheduler/ensemble_scheduler.h
@@ -83,6 +83,12 @@ struct EnsembleInfo {

// backward path, ensemble tensor to the step that provides its data
std::unordered_map<std::string, size_t> tensor_to_prev_step_;

// Maximum number of concurrent inflight responses from steps to downstream
// consumers. Prevents memory growth by blocking producers when the limit is
// reached. A value of 0 means unlimited (default). Configured via the
// 'max_ensemble_inflight_responses' parameter in the ensemble's config.pbtxt.
size_t max_inflight_responses_ = 0;
};

// Scheduler that implements ensemble scheduling.
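For completeness, a hypothetical config.pbtxt fragment showing how the new limit would be set on an ensemble is given below. The ensemble name and the value "8" are made up for illustration; omitting the parameter keeps the default unlimited behavior, and non-positive or unparsable values are rejected by the parsing code above and fall back to the default.

```protobuf
# Hypothetical ensemble config.pbtxt fragment; only the "parameters" entry is
# new. The ensemble name and the limit value are illustrative, and the
# ensemble_scheduling block is omitted for brevity.
name: "my_ensemble"
platform: "ensemble"

parameters: {
  key: "max_ensemble_inflight_responses"
  value: { string_value: "8" }
}
```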