|
14 | 14 |
|
15 | 15 |
|
16 | 16 | import time
|
| 17 | +from contextvars import Token |
17 | 18 | from dataclasses import dataclass, field
|
18 | 19 | from enum import Enum
|
19 | 20 | from typing import Any, Dict, List, Literal, Optional, Type, Union
|
20 | 21 | from uuid import UUID, uuid4
|
21 | 22 |
|
| 23 | +from typing_extensions import TypeAlias |
| 24 | + |
| 25 | +from opentelemetry.context import Context |
22 | 26 | from opentelemetry.trace import Span
|
23 | 27 | from opentelemetry.util.types import AttributeValue
|
24 | 28 |
|
# Alias for the type of ``LLMInvocation.context_token``: a ``contextvars``
# Token parameterized by the OpenTelemetry ``Context``.
ContextToken: TypeAlias = Token[Context]
| 30 | + |
25 | 31 |
|
26 | 32 | class ContentCapturingMode(Enum):
|
27 | 33 | # Do not capture content (default).
|
@@ -76,34 +82,46 @@ class OutputMessage:
|
76 | 82 | finish_reason: Union[str, FinishReason]
|
77 | 83 |
|
78 | 84 |
|
| 85 | +def _new_input_messages() -> list[InputMessage]: |
| 86 | + return [] |
| 87 | + |
| 88 | + |
| 89 | +def _new_output_messages() -> list[OutputMessage]: |
| 90 | + return [] |
| 91 | + |
| 92 | + |
| 93 | +def _new_str_any_dict() -> dict[str, Any]: |
| 94 | + return {} |
| 95 | + |
| 96 | + |
@dataclass
class LLMInvocation:
    """
    Represents a single LLM call invocation.

    When creating an LLMInvocation object, only set the data attributes.
    The ``span`` and ``context_token`` attributes are set by the
    TelemetryHandler.
    """

    # Only required field; everything below has a default.
    request_model: str
    # Set by the TelemetryHandler, not by callers (see class docstring).
    context_token: Optional[ContextToken] = None
    span: Optional[Span] = None
    # Defaults to the value of time.time() when the object is constructed.
    start_time: float = field(default_factory=time.time)
    end_time: Optional[float] = None
    # Mutable defaults go through factories so instances never share state.
    input_messages: List[InputMessage] = field(
        default_factory=_new_input_messages
    )
    output_messages: List[OutputMessage] = field(
        default_factory=_new_output_messages
    )
    provider: Optional[str] = None
    response_model_name: Optional[str] = None
    response_id: Optional[str] = None
    input_tokens: Optional[AttributeValue] = None
    output_tokens: Optional[AttributeValue] = None
    attributes: Dict[str, Any] = field(default_factory=_new_str_any_dict)
    # Ahead of upstream
    # NOTE(review): presumably the two fields below are not yet present in
    # the upstream package -- confirm before syncing.
    run_id: UUID = field(default_factory=uuid4)
    parent_run_id: Optional[UUID] = None
107 | 125 |
|
108 | 126 |
|
109 | 127 | @dataclass
|
|
0 commit comments