Skip to content

Commit d560f64

Browse files
committed
Add collect api for metrics
1 parent a8d83d0 commit d560f64

File tree

5 files changed

+140
-1
lines changed

5 files changed

+140
-1
lines changed

include/triton/core/tritonserver.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2720,6 +2720,18 @@ TRITONSERVER_DECLSPEC struct TRITONSERVER_Error* TRITONSERVER_MetricSet(
27202720
TRITONSERVER_DECLSPEC struct TRITONSERVER_Error* TRITONSERVER_MetricObserve(
27212721
struct TRITONSERVER_Metric* metric, double value);
27222722

2723+
/// Collect the current value of a metric.
2724+
/// Supports metrics of kind TRITONSERVER_METRIC_KIND_COUNTER,
2725+
/// TRITONSERVER_METRIC_KIND_GAUGE, TRITONSERVER_METRIC_KIND_HISTOGRAM and
2726+
/// returns TRITONSERVER_ERROR_UNSUPPORTED for unsupported
2727+
/// TRITONSERVER_MetricKind.
2728+
///
2729+
/// \param metric The metric object to collect.
2730+
/// \param value Returns the current value of the metric object; must point to a caller-allocated prometheus::ClientMetric.
2731+
/// \return a TRITONSERVER_Error indicating success or failure.
2732+
TRITONSERVER_DECLSPEC struct TRITONSERVER_Error* TRITONSERVER_MetricCollect(
2733+
struct TRITONSERVER_Metric* metric, void* value);
2734+
27232735
/// Get the TRITONSERVER_MetricKind of metric and its corresponding family.
27242736
///
27252737
/// \param metric The metric object to query.

src/metric_family.cc

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,21 @@ MetricFamily::Add(
7676
void* prom_metric = nullptr;
7777
switch (kind_) {
7878
case TRITONSERVER_METRIC_KIND_COUNTER: {
79+
if (buckets != nullptr) {
80+
throw std::invalid_argument(
81+
"Unexpected buckets found in counter Metric constructor.");
82+
}
7983
auto counter_family_ptr =
8084
reinterpret_cast<prometheus::Family<prometheus::Counter>*>(family_);
8185
auto counter_ptr = &counter_family_ptr->Add(label_map);
8286
prom_metric = reinterpret_cast<void*>(counter_ptr);
8387
break;
8488
}
8589
case TRITONSERVER_METRIC_KIND_GAUGE: {
90+
if (buckets != nullptr) {
91+
throw std::invalid_argument(
92+
"Unexpected buckets found in gauge Metric constructor.");
93+
}
8694
auto gauge_family_ptr =
8795
reinterpret_cast<prometheus::Family<prometheus::Gauge>*>(family_);
8896
auto gauge_ptr = &gauge_family_ptr->Add(label_map);
@@ -92,7 +100,7 @@ MetricFamily::Add(
92100
case TRITONSERVER_METRIC_KIND_HISTOGRAM: {
93101
if (buckets == nullptr) {
94102
throw std::invalid_argument(
95-
"Histogram must be constructed with bucket boundaries.");
103+
"Missing required buckets in histogram Metric constructor.");
96104
}
97105
auto histogram_family_ptr =
98106
reinterpret_cast<prometheus::Family<prometheus::Histogram>*>(family_);
@@ -394,6 +402,40 @@ Metric::Observe(double value)
394402
return nullptr; // Success
395403
}
396404

405+
TRITONSERVER_Error*
406+
Metric::Collect(prometheus::ClientMetric* value)
407+
{
408+
if (metric_ == nullptr) {
409+
return TRITONSERVER_ErrorNew(
410+
TRITONSERVER_ERROR_INTERNAL,
411+
"Could not collect metric value. Metric has been invalidated.");
412+
}
413+
414+
switch (kind_) {
415+
case TRITONSERVER_METRIC_KIND_COUNTER: {
416+
auto counter_ptr = reinterpret_cast<prometheus::Counter*>(metric_);
417+
*value = counter_ptr->Collect();
418+
break;
419+
}
420+
case TRITONSERVER_METRIC_KIND_GAUGE: {
421+
auto gauge_ptr = reinterpret_cast<prometheus::Gauge*>(metric_);
422+
*value = gauge_ptr->Collect();
423+
break;
424+
}
425+
case TRITONSERVER_METRIC_KIND_HISTOGRAM: {
426+
auto histogram_ptr = reinterpret_cast<prometheus::Histogram*>(metric_);
427+
*value = histogram_ptr->Collect();
428+
break;
429+
}
430+
default:
431+
return TRITONSERVER_ErrorNew(
432+
TRITONSERVER_ERROR_UNSUPPORTED,
433+
"Unsupported TRITONSERVER_MetricKind");
434+
}
435+
436+
return nullptr; // Success
437+
}
438+
397439
}} // namespace triton::core
398440

399441
#endif // TRITON_ENABLE_METRICS

src/metric_family.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class Metric {
9999
TRITONSERVER_Error* Increment(double value);
100100
TRITONSERVER_Error* Set(double value);
101101
TRITONSERVER_Error* Observe(double value);
102+
TRITONSERVER_Error* Collect(prometheus::ClientMetric* value);
102103

103104
// If a MetricFamily is deleted before its dependent Metric, we want to
104105
// invalidate the references so we don't access invalid memory.

src/test/metrics_api_test.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,32 @@ MetricAPIHelper(TRITONSERVER_Metric* metric, TRITONSERVER_MetricKind kind)
232232
TRITONSERVER_ErrorDelete(err);
233233
}
234234

235+
void
236+
HistogramAPIHelper(TRITONSERVER_Metric* metric)
237+
{
238+
// Observe
239+
std::vector<double> data{0.05, 1.5, 6.0};
240+
std::vector<std::uint64_t> cumulative_counts = {1, 1, 2, 2, 3, 3};
241+
double sum = 0.0;
242+
for (auto datum : data) {
243+
FAIL_TEST_IF_ERR(
244+
TRITONSERVER_MetricObserve(metric, datum), "observe metric value");
245+
sum += datum;
246+
}
247+
248+
// Collect
249+
prometheus::ClientMetric value;
250+
FAIL_TEST_IF_ERR(
251+
TRITONSERVER_MetricCollect(metric, &value),
252+
"query metric value after observe");
253+
auto hist = value.histogram;
254+
ASSERT_EQ(hist.sample_count, data.size());
255+
ASSERT_EQ(hist.sample_sum, sum);
256+
ASSERT_EQ(hist.bucket.size(), cumulative_counts.size());
257+
for (uint64_t i = 0; i < hist.bucket.size(); ++i) {
258+
ASSERT_EQ(hist.bucket[i].cumulative_count, cumulative_counts[i]);
259+
}
260+
}
235261

236262
// Test Fixture
237263
class MetricsApiTest : public ::testing::Test {
@@ -364,6 +390,52 @@ TEST_F(MetricsApiTest, TestGaugeEndToEnd)
364390
ASSERT_EQ(NumMetricMatches(server_, description), 0);
365391
}
366392

393+
// Test end-to-end flow of Generic Metrics API for Histogram metric
394+
TEST_F(MetricsApiTest, TestHistogramEndToEnd)
395+
{
396+
// Create metric family
397+
TRITONSERVER_MetricFamily* family;
398+
TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_HISTOGRAM;
399+
const char* name = "custom_histogram_example";
400+
const char* description =
401+
"this is an example histogram metric added via API.";
402+
FAIL_TEST_IF_ERR(
403+
TRITONSERVER_MetricFamilyNew(&family, kind, name, description),
404+
"Creating new metric family");
405+
406+
// Create metric
407+
TRITONSERVER_Metric* metric;
408+
std::vector<const TRITONSERVER_Parameter*> labels;
409+
labels.emplace_back(TRITONSERVER_ParameterNew(
410+
"example1", TRITONSERVER_PARAMETER_STRING, "histogram_label1"));
411+
labels.emplace_back(TRITONSERVER_ParameterNew(
412+
"example2", TRITONSERVER_PARAMETER_STRING, "histogram_label2"));
413+
std::vector<double> buckets = {0.1, 1.0, 2.5, 5.0, 10.0};
414+
FAIL_TEST_IF_ERR(
415+
TRITONSERVER_MetricNew(
416+
&metric, family, labels.data(), labels.size(),
417+
reinterpret_cast<void*>(&buckets)),
418+
"Creating new metric");
419+
for (const auto label : labels) {
420+
TRITONSERVER_ParameterDelete(const_cast<TRITONSERVER_Parameter*>(label));
421+
}
422+
423+
// Run through metric APIs and assert correctness
424+
HistogramAPIHelper(metric);
425+
426+
// Assert custom metric is reported and found in output
427+
ASSERT_EQ(NumMetricMatches(server_, description), 1);
428+
429+
// Cleanup
430+
FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric), "delete metric");
431+
FAIL_TEST_IF_ERR(
432+
TRITONSERVER_MetricFamilyDelete(family), "delete metric family");
433+
434+
// Assert custom metric/family is unregistered and no longer in output
435+
ASSERT_EQ(NumMetricMatches(server_, description), 0);
436+
}
437+
438+
367439
// Test that a duplicate metric family can't be added
368440
// with a conflicting type/kind
369441
TEST_F(MetricsApiTest, TestDupeMetricFamilyDiffKind)

src/tritonserver.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3463,6 +3463,18 @@ TRITONSERVER_MetricObserve(TRITONSERVER_Metric* metric, double value)
34633463
#endif // TRITON_ENABLE_METRICS
34643464
}
34653465

3466+
TRITONSERVER_Error*
3467+
TRITONSERVER_MetricCollect(TRITONSERVER_Metric* metric, void* value)
3468+
{
3469+
#ifdef TRITON_ENABLE_METRICS
3470+
return reinterpret_cast<tc::Metric*>(metric)->Collect(
3471+
reinterpret_cast<prometheus::ClientMetric*>(value));
3472+
#else
3473+
return TRITONSERVER_ErrorNew(
3474+
TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported");
3475+
#endif // TRITON_ENABLE_METRICS
3476+
}
3477+
34663478
TRITONSERVER_Error*
34673479
TRITONSERVER_GetMetricKind(
34683480
TRITONSERVER_Metric* metric, TRITONSERVER_MetricKind* kind)

0 commit comments

Comments
 (0)