# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# type: ignore
from __future__ import annotations

import threading
from collections.abc import Iterable
from concurrent import futures
from typing import Callable, Generator, Literal

import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceResponse)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
    TraceServiceServicer, add_TraceServiceServicer_to_server)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
from opentelemetry.sdk.environment_variables import (
    OTEL_EXPORTER_OTLP_TRACES_INSECURE)

from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes

FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"

FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
                    'array_value']


def decode_value(value: AnyValue):
    """Decode an OTLP AnyValue into the corresponding Python value."""
    field_decoders: dict[FieldName, Callable] = {
        "bool_value": (lambda v: v.bool_value),
        "string_value": (lambda v: v.string_value),
        "int_value": (lambda v: v.int_value),
        "double_value": (lambda v: v.double_value),
        "array_value":
        (lambda v: [decode_value(item) for item in v.array_value.values]),
    }
    for field, decoder in field_decoders.items():
        if value.HasField(field):
            return decoder(value)
    raise ValueError(f"Couldn't decode value: {value}")


def decode_attributes(attributes: Iterable[KeyValue]):
    """Convert OTLP KeyValue attributes into a plain Python dict."""
    return {kv.key: decode_value(kv.value) for kv in attributes}
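

# A minimal sketch (illustration only, not part of the original test) of what
# the decode helpers above do: OTLP attributes arrive as KeyValue/AnyValue
# protobuf messages, and decode_attributes() flattens them into a plain dict.
# The attribute keys below are made-up example strings, not real
# SpanAttributes constants.
def _example_decode_attributes() -> None:
    from opentelemetry.proto.common.v1.common_pb2 import ArrayValue

    kvs = [
        KeyValue(key="example.string", value=AnyValue(string_value="req-0")),
        KeyValue(key="example.int", value=AnyValue(int_value=256)),
        KeyValue(key="example.array",
                 value=AnyValue(array_value=ArrayValue(
                     values=[AnyValue(int_value=1), AnyValue(int_value=2)]))),
    ]
    assert decode_attributes(kvs) == {
        "example.string": "req-0",
        "example.int": 256,
        "example.array": [1, 2],
    }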


class FakeTraceService(TraceServiceServicer):
    """Minimal in-process collector that records the first trace export."""

    def __init__(self):
        self.request = None
        self.evt = threading.Event()

    def Export(self, request, context):
        # Remember the exported spans and wake up the waiting test.
        self.request = request
        self.evt.set()
        return ExportTraceServiceResponse()


@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
    """Fixture to set up a fake gRPC trace service."""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    service = FakeTraceService()
    add_TraceServiceServicer_to_server(service, server)
    server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
    server.start()

    yield service

    server.stop(None)


def test_traces(
    monkeypatch: pytest.MonkeyPatch,
    trace_service: FakeTraceService,
):
    with monkeypatch.context() as m:
        m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
        m.setenv("VLLM_USE_V1", "1")
        sampling_params = SamplingParams(
            temperature=0.01,
            top_p=0.1,
            max_tokens=256,
        )
        model = "facebook/opt-125m"
        llm = LLM(model=model,
                  otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
                  gpu_memory_utilization=0.3,
                  disable_log_stats=False)
        prompts = ["This is a short prompt"]
        outputs = llm.generate(prompts, sampling_params=sampling_params)
        print(f"test_traces outputs: {outputs}")

        timeout = 10
        if not trace_service.evt.wait(timeout):
            raise TimeoutError(
                f"The fake trace service didn't receive a trace within "
                f"the {timeout}-second timeout")

        request = trace_service.request
        assert len(request.resource_spans) == 1, (
            f"Expected 1 resource span, "
            f"but got {len(request.resource_spans)}")
        assert len(request.resource_spans[0].scope_spans) == 1, (
            f"Expected 1 scope span, "
            f"but got {len(request.resource_spans[0].scope_spans)}")
        assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
            f"Expected 1 span, "
            f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")

        attributes = decode_attributes(
            request.resource_spans[0].scope_spans[0].spans[0].attributes)
        # assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
                              ) == sampling_params.temperature
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
                              ) == sampling_params.max_tokens
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
        assert attributes.get(
            SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
                outputs[0].prompt_token_ids)
        completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
        assert attributes.get(
            SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens

        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) > 0
        assert attributes.get(
            SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) > 0
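

# Illustrative note (not part of the original change): the fixture above
# serves the fake collector on localhost:4317, so no real OTLP backend is
# needed and the test can be run directly with pytest, e.g.
#   pytest <path-to-this-file> -s
# where the path is wherever this test file lives in the repo.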