Skip to content

Commit c4ca0f2

Browse files
authored
add user_id and team_id as log facets (#321)
* add user_id and team_id as log facets, refactor a little * fix lint, remove draft comments
1 parent c86357f commit c4ca0f2

File tree

4 files changed

+55
-23
lines changed

4 files changed

+55
-23
lines changed

model-engine/model_engine_server/api/app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
from model_engine_server.api.tasks_v1 import inference_task_router_v1
2323
from model_engine_server.api.triggers_v1 import trigger_router_v1
2424
from model_engine_server.core.loggers import (
25+
LoggerTagKey,
26+
LoggerTagManager,
2527
filename_wo_ext,
26-
get_request_id,
2728
make_logger,
28-
set_request_id,
2929
)
3030

3131
logger = make_logger(filename_wo_ext(__name__))
@@ -47,11 +47,11 @@
4747
@app.middleware("http")
4848
async def dispatch(request: Request, call_next):
4949
try:
50-
set_request_id(str(uuid.uuid4()))
50+
LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4()))
5151
return await call_next(request)
5252
except Exception as e:
5353
tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
54-
request_id = get_request_id()
54+
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
5555
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
5656
structured_log = {
5757
"error": str(e),

model-engine/model_engine_server/api/dependencies.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from model_engine_server.core.auth.fake_authentication_repository import (
1414
FakeAuthenticationRepository,
1515
)
16-
from model_engine_server.core.loggers import filename_wo_ext, make_logger
16+
from model_engine_server.core.loggers import (
17+
LoggerTagKey,
18+
LoggerTagManager,
19+
filename_wo_ext,
20+
make_logger,
21+
)
1722
from model_engine_server.db.base import SessionAsync, SessionReadOnlyAsync
1823
from model_engine_server.domain.gateways import (
1924
CronJobGateway,
@@ -330,6 +335,10 @@ async def verify_authentication(
330335
headers={"WWW-Authenticate": "Basic"},
331336
)
332337

338+
# set logger context with identity data
339+
LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id)
340+
LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id)
341+
333342
return auth
334343

335344

model-engine/model_engine_server/api/llms_v1.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@
3636
)
3737
from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
3838
from model_engine_server.core.auth.authentication_repository import User
39-
from model_engine_server.core.loggers import filename_wo_ext, get_request_id, make_logger
39+
from model_engine_server.core.loggers import (
40+
LoggerTagKey,
41+
LoggerTagManager,
42+
filename_wo_ext,
43+
make_logger,
44+
)
4045
from model_engine_server.domain.exceptions import (
4146
EndpointDeleteFailedException,
4247
EndpointLabelsException,
@@ -82,7 +87,7 @@ def handle_streaming_exception(
8287
message: str,
8388
):
8489
tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
85-
request_id = get_request_id()
90+
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
8691
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
8792
structured_log = {
8893
"error": message,
@@ -223,7 +228,7 @@ async def create_completion_sync_task(
223228
user=auth, model_endpoint_name=model_endpoint_name, request=request
224229
)
225230
except UpstreamServiceError:
226-
request_id = get_request_id()
231+
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
227232
logger.exception(f"Upstream service error for request {request_id}")
228233
raise HTTPException(
229234
status_code=500,

model-engine/model_engine_server/core/loggers.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import sys
66
import warnings
77
from contextlib import contextmanager
8-
from typing import Optional, Sequence
8+
from enum import Enum
9+
from typing import Dict, Optional, Sequence
910

1011
import ddtrace
1112
import json_log_formatter
@@ -16,8 +17,6 @@
1617
LOG_FORMAT: str = "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] - %(message)s"
1718
# REQUIRED FOR DATADOG COMPATIBILITY
1819

19-
ctx_var_request_id = contextvars.ContextVar("ctx_var_request_id", default=None)
20-
2120
__all__: Sequence[str] = (
2221
# most common imports
2322
"make_logger",
@@ -35,19 +34,37 @@
3534
"loggers_at_level",
3635
# utils
3736
"filename_wo_ext",
38-
"get_request_id",
39-
"set_request_id",
37+
"LoggerTagKey",
38+
"LoggerTagManager",
4039
)
4140

4241

43-
def get_request_id() -> Optional[str]:
44-
"""Get the request id from the context variable."""
45-
return ctx_var_request_id.get()
42+
class LoggerTagKey(str, Enum):
43+
REQUEST_ID = "request_id"
44+
TEAM_ID = "team_id"
45+
USER_ID = "user_id"
46+
47+
48+
class LoggerTagManager:
49+
_context_vars: Dict[LoggerTagKey, contextvars.ContextVar] = {}
4650

51+
@classmethod
52+
def get(cls, key: LoggerTagKey) -> Optional[str]:
53+
"""Get the value from the context variable."""
54+
ctx_var = cls._context_vars.get(key)
55+
if ctx_var is not None:
56+
return ctx_var.get()
57+
return None
4758

48-
def set_request_id(request_id: str) -> None:
49-
"""Set the request id in the context variable."""
50-
ctx_var_request_id.set(request_id) # type: ignore
59+
@classmethod
60+
def set(cls, key: LoggerTagKey, value: Optional[str]) -> None:
61+
"""Set the value in the context variable."""
62+
if value is not None:
63+
ctx_var = cls._context_vars.get(key)
64+
if ctx_var is None:
65+
ctx_var = contextvars.ContextVar(f"ctx_var_{key.name.lower()}", default=None)
66+
cls._context_vars[key] = ctx_var
67+
ctx_var.set(value)
5168

5269

5370
def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Logger:
@@ -77,10 +94,11 @@ def json_record(self, message: str, extra: dict, record: logging.LogRecord) -> d
7794
extra["lineno"] = record.lineno
7895
extra["pathname"] = record.pathname
7996

80-
# add the http request id if it exists
81-
request_id = ctx_var_request_id.get()
82-
if request_id:
83-
extra["request_id"] = request_id
97+
# add additional logger tags
98+
for tag_key in LoggerTagKey:
99+
tag_value = LoggerTagManager.get(tag_key)
100+
if tag_value:
101+
extra[tag_key.value] = tag_value
84102

85103
current_span = tracer.current_span()
86104
extra["dd.trace_id"] = current_span.trace_id if current_span else 0

0 commit comments

Comments
 (0)