144 changes: 144 additions & 0 deletions src/ensemble_scheduler/ensemble_scheduler.cc
@@ -28,6 +28,7 @@

#include "ensemble_scheduler.h"

#include <condition_variable>
#include <mutex>

#include "cuda_utils.h"
@@ -45,6 +46,9 @@ class EnsembleContext;

using IterationCount = size_t;

// Timeout for backpressure condition-variable waits to prevent potential deadlocks
constexpr int kMutexTimeoutSeconds = 300;

// Check if the model is configured to preserve the order of responses.
// This is critical for async execution of ResponseComplete callbacks.
inline bool
@@ -370,6 +374,13 @@ class EnsembleContext {

size_t inflight_step_counter_;

// Backpressure support: Limits memory growth from decoupled models.
// Tracks inflight responses per step; blocks producers when downstream
// consumers are overloaded. Only active if max_inflight_responses_ > 0.
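// Note: std::mutex and std::condition_variable are neither movable nor
// copyable, so they are held through unique_ptr to allow storage in these
// resizable vectors.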
std::vector<size_t> step_inflight_response_counts_;
std::vector<std::unique_ptr<std::mutex>> step_mutexes_;
std::vector<std::unique_ptr<std::condition_variable>> step_cvs_;

// pointer that either points to 'pruned_tensor_to_step_' or to
// 'info_->tensor_to_step_' if all ensemble outputs are requested
std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_;
@@ -505,6 +516,20 @@ EnsembleContext::EnsembleContext(
}
}

// Initialize backpressure tracking if enabled.
size_t num_steps = info_->steps_.size();
step_inflight_response_counts_.resize(num_steps, 0);

if (info_->max_inflight_responses_ > 0) {
step_mutexes_.resize(num_steps);
step_cvs_.resize(num_steps);

for (size_t i = 0; i < num_steps; i++) {
step_mutexes_[i] = std::make_unique<std::mutex>();
step_cvs_[i] = std::make_unique<std::condition_variable>();
}
}

if (ensemble_status_.IsOk()) {
request_id_ = lrequest->Id();
correlation_id_ = lrequest->CorrelationId();
@@ -669,6 +694,55 @@ EnsembleContext::ResponseComplete(
auto pool = step_raw_ptr->ctx_->CallbackPool();
auto fn = [response, flags, step_raw_ptr]() {
auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
auto& context = step_ptr->ctx_;
size_t this_step_idx = step_ptr->step_idx_;
const auto& istep = context->info_->steps_[this_step_idx];

// Block this producer if downstream consumers are overloaded.
// Prevents memory exhaustion by limiting concurrent inflight responses.
if (context->info_->max_inflight_responses_ > 0 &&
!context->step_cvs_.empty()) {
for (const auto& output_pair : istep.output_to_tensor_) {
const auto& tensor_name = output_pair.second;
const auto& downstream_steps = (*context->tensor_to_step_)[tensor_name];

for (const auto& downstream_step_idx : downstream_steps) {
std::unique_lock<std::mutex> lk(
*context->step_mutexes_[downstream_step_idx]);

// Block if downstream inflight count >= limit. Timeout to prevent
// potential deadlock. Unblocks when downstream completes a request
// or request is cancelled.
auto timeout = std::chrono::seconds(kMutexTimeoutSeconds);
auto cancelled = [&]() {
auto& req = context->request_tracker_->Request();
return (req == nullptr) || req->IsCancelled();
};

bool capacity_available =
context->step_cvs_[downstream_step_idx]->wait_for(
lk, timeout, [&] {
return cancelled() ||
(context->step_inflight_response_counts_
[downstream_step_idx] <
context->info_->max_inflight_responses_);
});

// Log error only if timeout occurred (not cancellation).
if (!capacity_available && !cancelled()) {
LOG_ERROR
<< "[Internal Error] Ensemble '"
<< context->info_->ensemble_name_ << "' step " << this_step_idx
<< " blocked waiting for downstream step "
<< downstream_step_idx << " (inflight: "
<< context->step_inflight_response_counts_[downstream_step_idx]
<< " >= limit: " << context->info_->max_inflight_responses_
<< ") for " << kMutexTimeoutSeconds
<< " seconds. Proceeding to avoid deadlock.";
}
}
}
}
step_ptr->response_flags_ = flags;
step_ptr->response_ = response;

@@ -907,6 +981,16 @@ EnsembleContext::UpdateEnsembleState(
if (completed_step->response_flags_ &
TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
inflight_step_counter_--;

size_t completed_step_idx = completed_step->step_idx_;

// Decrement step_inflight_response_counts_, then notify any producer
// threads blocked waiting for this step's capacity
if (info_->max_inflight_responses_ > 0 && !step_cvs_.empty()) {
std::lock_guard<std::mutex> lk(*step_mutexes_[completed_step_idx]);
step_inflight_response_counts_[completed_step_idx]--;
step_cvs_[completed_step_idx]->notify_one();
}
}
RETURN_IF_ERROR(ConsumeResponse(completed_step));
updated_tensors->swap(completed_step->updated_tensors_);
@@ -950,6 +1034,13 @@ EnsembleContext::GetNextSteps(
for (const auto& idx : next_step_idx) {
steps->emplace_back();
RETURN_IF_ERROR(InitStep(idx.first, idx.second, &(steps->back())));

// Track as inflight. Checked by producers for backpressure; decremented on
// completion.
if (info_->max_inflight_responses_ > 0 && !step_mutexes_.empty()) {
std::lock_guard<std::mutex> lk(*step_mutexes_[idx.first]);
step_inflight_response_counts_[idx.first]++;
}
}
inflight_step_counter_ += steps->size();

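For orientation, below is a minimal standalone sketch of the bounded-inflight pattern that the three hunks above cooperate to implement: GetNextSteps increments a per-step counter, ResponseComplete waits with a timeout and a cancellation check, and UpdateEnsembleState decrements and notifies on a final response. The Backpressure struct, its member names, and the constants are illustrative only, not types from this change.

#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <thread>

// Illustrative stand-in for the per-step counter/mutex/cv triple above.
struct Backpressure {
  size_t limit = 0;       // 0 means the limit is disabled in the real code
  size_t inflight = 0;
  bool cancelled = false;
  std::mutex mu;
  std::condition_variable cv;

  // Dispatcher side (cf. GetNextSteps): a new request for the step is inflight.
  void OnDispatch()
  {
    std::lock_guard<std::mutex> lk(mu);
    ++inflight;
  }

  // Producer side (cf. ResponseComplete): block until capacity, cancellation,
  // or timeout. Returns false only when the wait timed out.
  bool WaitForCapacity(std::chrono::seconds timeout)
  {
    std::unique_lock<std::mutex> lk(mu);
    return cv.wait_for(
        lk, timeout, [&] { return cancelled || inflight < limit; });
  }

  // Consumer side (cf. UpdateEnsembleState): a final response was consumed;
  // free one slot and wake a blocked producer.
  void OnFinalResponse()
  {
    std::lock_guard<std::mutex> lk(mu);
    --inflight;
    cv.notify_one();
  }
};

int main()
{
  Backpressure bp;
  bp.limit = 2;

  // Two requests already inflight for the downstream step: at the limit.
  bp.OnDispatch();
  bp.OnDispatch();

  // A consumer frees a slot shortly after the producer starts waiting.
  std::thread consumer([&] {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    bp.OnFinalResponse();
  });

  // The producer blocks here until capacity becomes available or the timeout
  // elapses.
  bool ok = bp.WaitForCapacity(std::chrono::seconds(5));
  std::cout << (ok ? "capacity available" : "timed out") << "\n";

  consumer.join();
  return 0;
}

Returning true on cancellation mirrors the change above, where a cancelled request unblocks the wait without logging the timeout error.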
@@ -1445,12 +1536,52 @@ EnsembleContext::ScheduleSteps(

} // namespace

Status
EnsembleScheduler::ValidateConfig(const inference::ModelConfig& config)
{
// Validate max_ensemble_inflight_responses parameter if present
if (config.parameters().contains("max_ensemble_inflight_responses")) {
const auto& param =
config.parameters().at("max_ensemble_inflight_responses");
const std::string& value = param.string_value();

try {
const int parsed = std::stoi(value);
if (parsed <= 0) {
return Status(
Status::Code::INVALID_ARG,
"Invalid 'max_ensemble_inflight_responses' for ensemble model '" +
config.name() + "': value must be positive, got " +
std::to_string(parsed));
}
}
catch (const std::out_of_range& e) {
return Status(
Status::Code::INVALID_ARG,
"Invalid 'max_ensemble_inflight_responses' for ensemble model '" +
config.name() + "': value exceeds maximum allowed (" +
std::to_string(INT_MAX) + ")");
}
catch (const std::invalid_argument& e) {
return Status(
Status::Code::INVALID_ARG,
"Invalid 'max_ensemble_inflight_responses' for ensemble model '" +
config.name() + "': cannot parse value '" + value + "'");
}
}

return Status::Success;
}

Status
EnsembleScheduler::Create(
InferenceStatsAggregator* const stats_aggregator,
InferenceServer* const server, const ModelIdentifier& model_id,
const inference::ModelConfig& config, std::unique_ptr<Scheduler>* scheduler)
{
// Validate configuration before constructing scheduler
RETURN_IF_ERROR(ValidateConfig(config));

scheduler->reset(
new EnsembleScheduler(stats_aggregator, server, model_id, config));
return Status::Success;
@@ -1602,6 +1733,19 @@ EnsembleScheduler::EnsembleScheduler(
}
}
callback_pool_ = is_->EnsembleCallbackPool();

// Parse backpressure configuration. Limits concurrent responses from
// decoupled steps to prevent memory growth.
// Configuration is already validated in ValidateConfig()
if (config.parameters().contains("max_ensemble_inflight_responses")) {
const auto& param =
config.parameters().at("max_ensemble_inflight_responses");
info_->max_inflight_responses_ =
static_cast<size_t>(std::stoi(param.string_value()));
LOG_INFO << "Ensemble model '" << config.name()
<< "' configured with max_ensemble_inflight_responses: "
<< info_->max_inflight_responses_;
}
}
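
A usage sketch for the ensemble's config.pbtxt: only the parameter key 'max_ensemble_inflight_responses' comes from this change; the value "8" and the surrounding layout are illustrative of the standard ModelConfig parameters map.

parameters: {
  key: "max_ensemble_inflight_responses"
  value: { string_value: "8" }
}

With this set, ValidateConfig() accepts the value at model load and the constructor above stores it in info_->max_inflight_responses_; omitting the parameter leaves the limit at 0 (unlimited).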

EnsembleScheduler::~EnsembleScheduler()
10 changes: 10 additions & 0 deletions src/ensemble_scheduler/ensemble_scheduler.h
@@ -83,6 +83,12 @@ struct EnsembleInfo {

// backward path, ensemble tensor to the step that provides its data
std::unordered_map<std::string, size_t> tensor_to_prev_step_;

// Maximum concurrent inflight responses from steps to downstream
// consumers. Prevents memory growth by blocking producers when the limit is
// reached.
// Value of 0 means unlimited (default). Configured via parameter
// 'max_ensemble_inflight_responses' in ensemble config.pbtxt.
size_t max_inflight_responses_ = 0;
};

// Scheduler that implements ensemble scheduling.
@@ -116,6 +122,10 @@ class EnsembleScheduler : public Scheduler {
InferenceServer* const server, const ModelIdentifier& model_id,
const inference::ModelConfig& config);

// Validates ensemble configuration parameters before construction.
// Returns an error Status if the configuration is invalid.
static Status ValidateConfig(const inference::ModelConfig& config);

void CacheLookUp(
std::unique_ptr<InferenceRequest>& request,
std::unique_ptr<InferenceResponse>& cached_response);