|
28 | 28 |
|
29 | 29 | #include "ensemble_scheduler.h" |
30 | 30 |
|
| 31 | +#include <condition_variable> |
31 | 32 | #include <mutex> |
32 | 33 |
|
33 | 34 | #include "cuda_utils.h" |
@@ -370,6 +371,13 @@ class EnsembleContext { |
370 | 371 |
|
371 | 372 | size_t inflight_step_counter_; |
372 | 373 |
|
| 374 | + // Backpressure support: Limits memory growth from decoupled models. |
| 375 | + // Tracks inflight responses per step; blocks producers when downstream |
| 376 | + // consumers are overloaded. Only active if max_inflight_responses_ > 0. |
| 377 | + std::vector<size_t> step_inflight_response_counts_; |
| 378 | + std::vector<std::unique_ptr<std::mutex>> step_mutexes_; |
| 379 | + std::vector<std::unique_ptr<std::condition_variable>> step_cvs_; |
| 380 | + |
373 | 381 | // pointer that either points to 'pruned_tensor_to_step_' or to |
374 | 382 | // 'info_->tensor_to_step_' if all ensemble outputs are requested |
375 | 383 | std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_; |
@@ -505,6 +513,20 @@ EnsembleContext::EnsembleContext( |
505 | 513 | } |
506 | 514 | } |
507 | 515 |
|
| 516 | + // Initialize backpressure tracking if enabled. |
| 517 | + size_t num_steps = info_->steps_.size(); |
| 518 | + step_inflight_response_counts_.resize(num_steps, 0); |
| 519 | + |
| 520 | + if (info_->max_inflight_responses_ > 0) { |
| 521 | + step_mutexes_.resize(num_steps); |
| 522 | + step_cvs_.resize(num_steps); |
| 523 | + |
| 524 | + for (size_t i = 0; i < num_steps; i++) { |
| 525 | + step_mutexes_[i].reset(new std::mutex()); |
| 526 | + step_cvs_[i].reset(new std::condition_variable()); |
| 527 | + } |
| 528 | + } |
| 529 | + |
508 | 530 | if (ensemble_status_.IsOk()) { |
509 | 531 | request_id_ = lrequest->Id(); |
510 | 532 | correlation_id_ = lrequest->CorrelationId(); |
@@ -669,6 +691,46 @@ EnsembleContext::ResponseComplete( |
669 | 691 | auto pool = step_raw_ptr->ctx_->CallbackPool(); |
670 | 692 | auto fn = [response, flags, step_raw_ptr]() { |
671 | 693 | auto step_ptr = std::unique_ptr<Step>(step_raw_ptr); |
| 694 | + auto& context = step_ptr->ctx_; |
| 695 | + size_t this_step_idx = step_ptr->step_idx_; |
| 696 | + const auto& istep = context->info_->steps_[this_step_idx]; |
| 697 | + |
| 698 | + // Block this producer if downstream consumers are overloaded. |
| 699 | + // Prevents memory exhaustion by limiting concurrent inflight responses. |
| 700 | + if (context->info_->max_inflight_responses_ > 0 && |
| 701 | + !context->step_cvs_.empty()) { |
| 702 | + for (const auto& output_pair : istep.output_to_tensor_) { |
| 703 | + const auto& tensor_name = output_pair.second; |
| 704 | + const auto& downstream_steps = (*context->tensor_to_step_)[tensor_name]; |
| 705 | + |
| 706 | + for (const auto& downstream_step_idx : downstream_steps) { |
| 707 | + std::unique_lock<std::mutex> lk( |
| 708 | + *context->step_mutexes_[downstream_step_idx]); |
| 709 | + |
| 710 | + // Block if downstream inflight count >= limit. Timeout after 300s to |
| 711 | + // prevent any deadlock. Unblocks when downstream completes a request. |
| 712 | + auto timeout = std::chrono::seconds(300); |
| 713 | + bool capacity_available = |
| 714 | + context->step_cvs_[downstream_step_idx]->wait_for( |
| 715 | + lk, timeout, [&] { |
| 716 | + return context->step_inflight_response_counts_ |
| 717 | + [downstream_step_idx] < |
| 718 | + context->info_->max_inflight_responses_; |
| 719 | + }); |
| 720 | + |
| 721 | + if (!capacity_available) { |
| 722 | + LOG_ERROR |
| 723 | + << "[Internal Error] Ensemble '" |
| 724 | + << context->info_->ensemble_name_ << "' step " << this_step_idx |
| 725 | + << " blocked waiting for downstream step " |
| 726 | + << downstream_step_idx << " (inflight: " |
| 727 | + << context->step_inflight_response_counts_[downstream_step_idx] |
| 728 | + << " >= limit: " << context->info_->max_inflight_responses_ |
| 729 | + << ") for 300 seconds. Proceeding to avoid deadlock."; |
| 730 | + } |
| 731 | + } |
| 732 | + } |
| 733 | + } |
672 | 734 | step_ptr->response_flags_ = flags; |
673 | 735 | step_ptr->response_ = response; |
674 | 736 |
|
@@ -907,6 +969,15 @@ EnsembleContext::UpdateEnsembleState( |
907 | 969 | if (completed_step->response_flags_ & |
908 | 970 | TRITONSERVER_RESPONSE_COMPLETE_FINAL) { |
909 | 971 | inflight_step_counter_--; |
| 972 | + |
| 973 | + size_t completed_step_idx = completed_step->step_idx_; |
| 974 | + step_inflight_response_counts_[completed_step_idx]--; |
| 975 | + |
| 976 | + // Notify any producer threads blocked waiting for this step's capacity |
| 977 | + if (info_->max_inflight_responses_ > 0 && !step_cvs_.empty()) { |
| 978 | + std::lock_guard<std::mutex> lk(*step_mutexes_[completed_step_idx]); |
| 979 | + step_cvs_[completed_step_idx]->notify_one(); |
| 980 | + } |
910 | 981 | } |
911 | 982 | RETURN_IF_ERROR(ConsumeResponse(completed_step)); |
912 | 983 | updated_tensors->swap(completed_step->updated_tensors_); |
@@ -950,6 +1021,10 @@ EnsembleContext::GetNextSteps( |
950 | 1021 | for (const auto& idx : next_step_idx) { |
951 | 1022 | steps->emplace_back(); |
952 | 1023 | RETURN_IF_ERROR(InitStep(idx.first, idx.second, &(steps->back()))); |
| 1024 | + |
| 1025 | + // Track as inflight. Checked by producers for backpressure; decremented on |
| 1026 | + // completion. |
| 1027 | + step_inflight_response_counts_[idx.first]++; |
953 | 1028 | } |
954 | 1029 | inflight_step_counter_ += steps->size(); |
955 | 1030 |
|
@@ -1602,6 +1677,32 @@ EnsembleScheduler::EnsembleScheduler( |
1602 | 1677 | } |
1603 | 1678 | } |
1604 | 1679 | callback_pool_ = is_->EnsembleCallbackPool(); |
| 1680 | + |
| 1681 | + // Parse backpressure configuration. Limits concurrent responses from |
| 1682 | + // decoupled steps to prevent memory growth. |
| 1683 | + if (config.parameters().contains("max_ensemble_inflight_responses")) { |
| 1684 | + const auto& param = |
| 1685 | + config.parameters().at("max_ensemble_inflight_responses"); |
| 1686 | + const std::string& value = param.string_value(); |
| 1687 | + try { |
| 1688 | + const int64_t size = std::stoll(value); |
| 1689 | + if (size > 0) { |
| 1690 | + info_->max_inflight_responses_ = static_cast<size_t>(size); |
| 1691 | + LOG_INFO << "Ensemble model '" << config.name() |
| 1692 | + << "' configured with max_ensemble_inflight_responses: " |
| 1693 | + << info_->max_inflight_responses_; |
| 1694 | + } else { |
| 1695 | + LOG_ERROR |
| 1696 | + << "Ignoring 'max_ensemble_inflight_responses' for ensemble model '" |
| 1697 | + << config.name() << "': value must be positive, got " << size; |
| 1698 | + } |
| 1699 | + } |
| 1700 | + catch (const std::invalid_argument& ia) { |
| 1701 | + LOG_ERROR |
| 1702 | + << "Failed to parse 'max_ensemble_inflight_responses' for ensemble '" |
| 1703 | + << config.name() << "': " << ia.what(); |
| 1704 | + } |
| 1705 | + } |
1605 | 1706 | } |
1606 | 1707 |
|
1607 | 1708 | EnsembleScheduler::~EnsembleScheduler() |
|
0 commit comments