|
| 1 | +package sttp.tapir.server.metrics.prometheus_simpleclient |
| 2 | + |
| 3 | +import io.prometheus.client.exporter.common.TextFormat |
| 4 | +import io.prometheus.client.{CollectorRegistry, Counter, Gauge, Histogram} |
| 5 | +import sttp.monad.MonadError |
| 6 | +import sttp.tapir.CodecFormat.TextPlain |
| 7 | +import sttp.tapir._ |
| 8 | +import sttp.tapir.server.ServerEndpoint |
| 9 | +import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor |
| 10 | +import sttp.tapir.server.metrics.{EndpointMetric, Metric, MetricLabels} |
| 11 | + |
| 12 | +import java.io.StringWriter |
| 13 | +import java.time.{Clock, Duration} |
| 14 | + |
| 15 | +case class PrometheusMetrics[F[_]]( |
| 16 | + namespace: String = "tapir", |
| 17 | + registry: CollectorRegistry = CollectorRegistry.defaultRegistry, |
| 18 | + metrics: List[Metric[F, _]] = List.empty[Metric[F, _]], |
| 19 | + endpointPrefix: EndpointInput[Unit] = "metrics" |
| 20 | +) { |
| 21 | + import PrometheusMetrics._ |
| 22 | + |
| 23 | + /** An endpoint exposing the current metric values. */ |
| 24 | + lazy val metricsEndpoint: ServerEndpoint[Any, F] = ServerEndpoint.public( |
| 25 | + endpoint.get.in(endpointPrefix).out(plainBody[CollectorRegistry]), |
| 26 | + (monad: MonadError[F]) => (_: Unit) => monad.eval(Right(registry): Either[Unit, CollectorRegistry]) |
| 27 | + ) |
| 28 | + |
| 29 | + /** Registers a `$namespace_request_active{path, method}` gauge (assuming default labels). */ |
| 30 | + def addRequestsActive(labels: MetricLabels = MetricLabels.Default): PrometheusMetrics[F] = |
| 31 | + copy(metrics = metrics :+ requestActive(registry, namespace, labels)) |
| 32 | + |
| 33 | + /** Registers a `$namespace_request_total{path, method, status}` counter (assuming default labels). */ |
| 34 | + def addRequestsTotal(labels: MetricLabels = MetricLabels.Default): PrometheusMetrics[F] = |
| 35 | + copy(metrics = metrics :+ requestTotal(registry, namespace, labels)) |
| 36 | + |
| 37 | + /** Registers a `$namespace_request_duration_seconds{path, method, status, phase}` histogram (assuming default labels). */ |
| 38 | + def addRequestsDuration( |
| 39 | + labels: MetricLabels = MetricLabels.Default, |
| 40 | + clock: Clock = Clock.systemUTC(), |
| 41 | + bucketsOverride: List[Double] = List.empty |
| 42 | + ): PrometheusMetrics[F] = |
| 43 | + copy(metrics = metrics :+ requestDuration(registry, namespace, labels, clock, bucketsOverride)) |
| 44 | + |
| 45 | + /** Registers a custom metric. */ |
| 46 | + def addCustom(m: Metric[F, _]): PrometheusMetrics[F] = copy(metrics = metrics :+ m) |
| 47 | + |
| 48 | + /** The interceptor which can be added to a server's options, to enable metrics collection. */ |
| 49 | + def metricsInterceptor(ignoreEndpoints: Seq[AnyEndpoint] = Seq.empty): MetricsRequestInterceptor[F] = |
| 50 | + new MetricsRequestInterceptor[F](metrics, ignoreEndpoints :+ metricsEndpoint.endpoint) |
| 51 | +} |
| 52 | + |
| 53 | +object PrometheusMetrics { |
| 54 | + |
| 55 | + implicit val schemaForCollectorRegistry: Schema[CollectorRegistry] = Schema.string[CollectorRegistry] |
| 56 | + |
| 57 | + implicit val collectorRegistryCodec: Codec[String, CollectorRegistry, CodecFormat.TextPlain] = |
| 58 | + Codec.anyString(TextPlain())(_ => DecodeResult.Value(new CollectorRegistry()))(r => { |
| 59 | + val output = new StringWriter() |
| 60 | + TextFormat.write004(output, r.metricFamilySamples) |
| 61 | + output.close() |
| 62 | + output.toString |
| 63 | + }) |
| 64 | + |
| 65 | + /** Using the default namespace and labels, registers the following metrics: |
| 66 | + * |
| 67 | + * - `$namespace_request_active{path, method}` (gauge) |
| 68 | + * - `$namespace_request_total{path, method, status}` (counter) |
| 69 | + * - `$namespace_request_duration_seconds{path, method, status, phase}` (histogram) |
| 70 | + * |
| 71 | + * Status is by default the status code class (1xx, 2xx, etc.), and phase can be either `headers` or `body` - request duration is |
| 72 | + * measured separately up to the point where the headers are determined, and then once again when the whole response body is complete. |
| 73 | + */ |
| 74 | + def default[F[_]]( |
| 75 | + namespace: String = "tapir", |
| 76 | + registry: CollectorRegistry = CollectorRegistry.defaultRegistry, |
| 77 | + labels: MetricLabels = MetricLabels.Default |
| 78 | + ): PrometheusMetrics[F] = |
| 79 | + PrometheusMetrics( |
| 80 | + namespace, |
| 81 | + registry, |
| 82 | + List( |
| 83 | + requestActive(registry, namespace, labels), |
| 84 | + requestTotal(registry, namespace, labels), |
| 85 | + requestDuration(registry, namespace, labels) |
| 86 | + ) |
| 87 | + ) |
| 88 | + |
| 89 | + def requestActive[F[_]](registry: CollectorRegistry, namespace: String, labels: MetricLabels): Metric[F, Gauge] = |
| 90 | + Metric[F, Gauge]( |
| 91 | + Gauge |
| 92 | + .build() |
| 93 | + .namespace(namespace) |
| 94 | + .name("request_active") |
| 95 | + .help("Active HTTP requests") |
| 96 | + .labelNames(labels.namesForRequest: _*) |
| 97 | + .create() |
| 98 | + .register(registry), |
| 99 | + onRequest = { (req, gauge, m) => |
| 100 | + m.unit { |
| 101 | + EndpointMetric() |
| 102 | + .onEndpointRequest { ep => m.eval(gauge.labels(labels.valuesForRequest(ep, req): _*).inc()) } |
| 103 | + .onResponseBody { (ep, _) => m.eval(gauge.labels(labels.valuesForRequest(ep, req): _*).dec()) } |
| 104 | + .onException { (ep, _) => m.eval(gauge.labels(labels.valuesForRequest(ep, req): _*).dec()) } |
| 105 | + } |
| 106 | + } |
| 107 | + ) |
| 108 | + |
| 109 | + def requestTotal[F[_]](registry: CollectorRegistry, namespace: String, labels: MetricLabels): Metric[F, Counter] = |
| 110 | + Metric[F, Counter]( |
| 111 | + Counter |
| 112 | + .build() |
| 113 | + .namespace(namespace) |
| 114 | + .name("request_total") |
| 115 | + .help("Total HTTP requests") |
| 116 | + .labelNames(labels.namesForRequest ++ labels.namesForResponse: _*) |
| 117 | + .register(registry), |
| 118 | + onRequest = { (req, counter, m) => |
| 119 | + m.unit { |
| 120 | + EndpointMetric() |
| 121 | + .onResponseBody { (ep, res) => |
| 122 | + m.eval(counter.labels(labels.valuesForRequest(ep, req) ++ labels.valuesForResponse(res): _*).inc()) |
| 123 | + } |
| 124 | + .onException { (ep, ex) => m.eval(counter.labels(labels.valuesForRequest(ep, req) ++ labels.valuesForResponse(ex): _*).inc()) } |
| 125 | + } |
| 126 | + } |
| 127 | + ) |
| 128 | + |
| 129 | + def requestDuration[F[_]]( |
| 130 | + registry: CollectorRegistry, |
| 131 | + namespace: String, |
| 132 | + labels: MetricLabels, |
| 133 | + clock: Clock = Clock.systemUTC(), |
| 134 | + bucketsOverride: List[Double] = List.empty |
| 135 | + ): Metric[F, Histogram] = |
| 136 | + Metric[F, Histogram]( |
| 137 | + (if (bucketsOverride.nonEmpty) Histogram.build().buckets(bucketsOverride: _*) else Histogram.build()) |
| 138 | + .namespace(namespace) |
| 139 | + .name("request_duration_seconds") |
| 140 | + .help("Duration of HTTP requests") |
| 141 | + .labelNames(labels.namesForRequest ++ labels.namesForResponse ++ List(labels.forResponsePhase.name): _*) |
| 142 | + .register(registry), |
| 143 | + onRequest = { (req, histogram, m) => |
| 144 | + m.eval { |
| 145 | + val requestStart = clock.instant() |
| 146 | + def duration = Duration.between(requestStart, clock.instant()).toMillis.toDouble / 1000.0 |
| 147 | + EndpointMetric() |
| 148 | + .onResponseHeaders { (ep, res) => |
| 149 | + m.eval( |
| 150 | + histogram |
| 151 | + .labels( |
| 152 | + labels.valuesForRequest(ep, req) ++ labels.valuesForResponse(res) ++ List(labels.forResponsePhase.headersValue): _* |
| 153 | + ) |
| 154 | + .observe(duration) |
| 155 | + ) |
| 156 | + } |
| 157 | + .onResponseBody { (ep, res) => |
| 158 | + m.eval( |
| 159 | + histogram |
| 160 | + .labels(labels.valuesForRequest(ep, req) ++ labels.valuesForResponse(res) ++ List(labels.forResponsePhase.bodyValue): _*) |
| 161 | + .observe(duration) |
| 162 | + ) |
| 163 | + } |
| 164 | + .onException { (ep, ex) => |
| 165 | + m.eval( |
| 166 | + histogram |
| 167 | + .labels(labels.valuesForRequest(ep, req) ++ labels.valuesForResponse(ex) ++ List(labels.forResponsePhase.bodyValue): _*) |
| 168 | + .observe(duration) |
| 169 | + ) |
| 170 | + } |
| 171 | + } |
| 172 | + } |
| 173 | + ) |
| 174 | +} |
0 commit comments