Skip to content

Commit ba74cf8

Browse files
baz review fixes
1 parent 853c790 commit ba74cf8

File tree

1 file changed

+68
-43
lines changed

1 file changed

+68
-43
lines changed

src/galileo/__future__/metric.py

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,12 @@ def my_scorer(trace_or_span):
119119
updated_at: datetime | None
120120
version: int | None
121121

122+
# Scorer defaults - available for LLM and built-in Galileo metrics
123+
# These are returned by the API in the ScorerDefaults object
124+
model: str | None
125+
judges: int | None
126+
cot_enabled: bool | None
127+
122128
def __init__(
123129
self, name: str, *, description: str = "", tags: list[str] | None = None, version: int | None = None
124130
) -> None:
@@ -140,8 +146,42 @@ def __init__(
140146
self.created_at = None
141147
self.updated_at = None
142148
self.scorer_type = None
149+
150+
# Initialize scorer defaults (populated from API for LLM and Galileo metrics)
151+
self.model = None
152+
self.judges = None
153+
self.cot_enabled = None
154+
143155
self._set_state(SyncState.LOCAL_ONLY)
144156

157+
@classmethod
158+
def _create_metric_from_type(cls, scorer_type: ScorerTypes) -> Metric:
159+
"""
160+
Create the appropriate Metric subclass instance based on scorer_type.
161+
162+
This is a factory method that centralizes the logic for instantiating
163+
the correct metric subclass based on the scorer type returned from the API.
164+
165+
Args:
166+
scorer_type: The scorer type from the API response.
167+
168+
Returns
169+
-------
170+
Metric: An uninitialized instance of the appropriate subclass
171+
(LlmMetric, CodeMetric, or GalileoMetric).
172+
173+
Examples
174+
--------
175+
instance = Metric._create_metric_from_type(ScorerTypes.LLM)
176+
# Returns: LlmMetric instance
177+
"""
178+
if scorer_type == ScorerTypes.LLM:
179+
return LlmMetric.__new__(LlmMetric)
180+
if scorer_type == ScorerTypes.CODE:
181+
return CodeMetric.__new__(CodeMetric)
182+
# Default to GalileoMetric for built-in scorers (LUNA, PRESET, etc.)
183+
return GalileoMetric.__new__(GalileoMetric)
184+
145185
@classmethod
146186
def get(cls, *, id: str | None = None, name: str | None = None) -> Metric | None:
147187
"""
@@ -190,17 +230,8 @@ def get(cls, *, id: str | None = None, name: str | None = None) -> Metric | None
190230
if retrieved_scorer is None:
191231
return None
192232

193-
# Determine appropriate subclass based on scorer_type
194-
scorer_type = retrieved_scorer.scorer_type
195-
instance: Metric
196-
if scorer_type == ScorerTypes.LLM:
197-
instance = LlmMetric.__new__(LlmMetric)
198-
elif scorer_type == ScorerTypes.CODE:
199-
instance = CodeMetric.__new__(CodeMetric)
200-
else:
201-
# Default to GalileoMetric for built-in scorers
202-
instance = GalileoMetric.__new__(GalileoMetric)
203-
233+
# Create appropriate subclass instance based on scorer_type
234+
instance = cls._create_metric_from_type(retrieved_scorer.scorer_type)
204235
StateManagementMixin.__init__(instance)
205236
instance._populate_from_scorer_response(retrieved_scorer)
206237
instance._set_state(SyncState.SYNCED)
@@ -241,17 +272,8 @@ def list(
241272

242273
result: builtins.list[Metric] = []
243274
for retrieved_scorer in retrieved_scorers:
244-
# Determine appropriate subclass based on scorer_type
245-
scorer_type = retrieved_scorer.scorer_type
246-
instance: Metric
247-
if scorer_type == ScorerTypes.LLM:
248-
instance = LlmMetric.__new__(LlmMetric)
249-
elif scorer_type == ScorerTypes.CODE:
250-
instance = CodeMetric.__new__(CodeMetric)
251-
else:
252-
# Default to GalileoMetric for built-in scorers
253-
instance = GalileoMetric.__new__(GalileoMetric)
254-
275+
# Create appropriate subclass instance based on scorer_type
276+
instance = cls._create_metric_from_type(retrieved_scorer.scorer_type)
255277
StateManagementMixin.__init__(instance)
256278
instance._populate_from_scorer_response(retrieved_scorer)
257279
instance._set_state(SyncState.SYNCED)
@@ -310,27 +332,28 @@ def _populate_from_scorer_response(self, scorer_response: Any) -> None:
310332
self.created_at = None if isinstance(scorer_response.created_at, Unset) else scorer_response.created_at
311333
self.updated_at = None if isinstance(scorer_response.updated_at, Unset) else scorer_response.updated_at
312334

335+
# Extract defaults - available for LLM and built-in Galileo metrics
336+
# These are returned by the API for preset scorers too
337+
if not isinstance(scorer_response.defaults, Unset) and scorer_response.defaults is not None:
338+
self.model = (
339+
scorer_response.defaults.model_name if hasattr(scorer_response.defaults, "model_name") else None
340+
)
341+
self.judges = (
342+
scorer_response.defaults.num_judges if hasattr(scorer_response.defaults, "num_judges") else None
343+
)
344+
self.cot_enabled = (
345+
scorer_response.defaults.cot_enabled if hasattr(scorer_response.defaults, "cot_enabled") else None
346+
)
347+
else:
348+
self.model = None
349+
self.judges = None
350+
self.cot_enabled = None
351+
313352
# LLM-specific attributes (only set if this is an LlmMetric)
314353
if isinstance(self, LlmMetric):
315354
self.output_type = None if isinstance(scorer_response.output_type, Unset) else scorer_response.output_type
316355
self.prompt = None if isinstance(scorer_response.user_prompt, Unset) else scorer_response.user_prompt
317356

318-
# Extract defaults
319-
if not isinstance(scorer_response.defaults, Unset) and scorer_response.defaults is not None:
320-
self.model = (
321-
scorer_response.defaults.model_name if hasattr(scorer_response.defaults, "model_name") else None
322-
)
323-
self.judges = (
324-
scorer_response.defaults.num_judges if hasattr(scorer_response.defaults, "num_judges") else None
325-
)
326-
self.cot_enabled = (
327-
scorer_response.defaults.cot_enabled if hasattr(scorer_response.defaults, "cot_enabled") else None
328-
)
329-
else:
330-
self.model = None
331-
self.judges = None
332-
self.cot_enabled = None
333-
334357
# Extract scoreable node types
335358
if not isinstance(scorer_response.scoreable_node_types, Unset) and scorer_response.scoreable_node_types:
336359
try:
@@ -628,10 +651,10 @@ def create(self) -> LlmMetric:
628651
created_version = metrics_service.create_custom_llm_metric(
629652
name=self.name,
630653
user_prompt=self.prompt or "",
631-
node_level=self.node_level or StepType.llm,
632-
cot_enabled=self.cot_enabled or True,
633-
model_name=self.model or Configuration.default_scorer_model,
634-
num_judges=self.judges or Configuration.default_scorer_judges,
654+
node_level=self.node_level if self.node_level is not None else StepType.llm,
655+
cot_enabled=self.cot_enabled if self.cot_enabled is not None else True,
656+
model_name=self.model if self.model is not None else Configuration.default_scorer_model,
657+
num_judges=self.judges if self.judges is not None else Configuration.default_scorer_judges,
635658
description=self.description,
636659
tags=self.tags,
637660
output_type=self.output_type
@@ -840,4 +863,6 @@ def my_scorer(trace):
840863

841864
def __repr__(self) -> str:
842865
"""Detailed string representation of the metric."""
843-
return f"LocalMetric(name='{self.name}', scorer_fn={self.scorer_fn.__name__})"
866+
# Handle callables that don't have __name__ (partials, lambdas, callable instances)
867+
fn_name = getattr(self.scorer_fn, "__name__", f"<{type(self.scorer_fn).__name__}>")
868+
return f"LocalMetric(name='{self.name}', scorer_fn={fn_name})"

0 commit comments

Comments
 (0)