|
17 | 17 | import time |
18 | 18 | import uuid |
19 | 19 | from abc import abstractmethod |
20 | | -from typing import Any |
| 20 | +from typing import Any, Optional |
21 | 21 |
|
22 | 22 | from google.adk import Runner |
23 | 23 | from google.adk.evaluation.eval_set import EvalSet |
@@ -210,33 +210,43 @@ def _build_eval_set_from_tracing_json(self, tracing_json_path: str) -> EvalSet: |
210 | 210 |
|
211 | 211 | return evalset |
212 | 212 |
|
213 | | - def build_eval_set(self, file_path: str): |
| 213 | + def build_eval_set( |
| 214 | + self, eval_set: Optional[EvalSet] = None, file_path: Optional[str] = None |
| 215 | + ): |
214 | 216 | """Generate evaluation data from a given file and assign it to the class attribute `invocation_list`.""" |
215 | | - eval_case_data_list: list[EvalTestCase] = [] |
216 | 217 |
|
217 | | - try: |
218 | | - with open(file_path, "r", encoding="utf-8") as f: |
219 | | - file_content = json.load(f) |
220 | | - except json.JSONDecodeError as e: |
221 | | - raise ValueError(f"Invalid JSON format in file {file_path}: {e}") |
222 | | - except Exception as e: |
223 | | - raise ValueError(f"Error reading file {file_path}: {e}") |
224 | | - |
225 | | - if isinstance(file_content, dict) and "eval_cases" in file_content: |
226 | | - eval_cases = self._build_eval_set_from_eval_json(file_path).eval_cases |
227 | | - elif ( |
228 | | - isinstance(file_content, list) |
229 | | - and len(file_content) > 0 |
230 | | - and all( |
231 | | - isinstance(span, dict) and "trace_id" in span for span in file_content |
232 | | - ) |
233 | | - ): |
234 | | - eval_cases = self._build_eval_set_from_tracing_json(file_path).eval_cases |
| 218 | + if eval_set is None and file_path is None: |
| 219 | + raise ValueError("eval_set or file_path is required") |
| 220 | + if eval_set: |
| 221 | + eval_cases = eval_set.eval_cases |
235 | 222 | else: |
236 | | - raise ValueError( |
237 | | - f"Unsupported file format in {file_path}. Please provide a valid file." |
238 | | - ) |
| 223 | + try: |
| 224 | + with open(file_path, "r", encoding="utf-8") as f: |
| 225 | + file_content = json.load(f) |
| 226 | + except json.JSONDecodeError as e: |
| 227 | + raise ValueError(f"Invalid JSON format in file {file_path}: {e}") |
| 228 | + except Exception as e: |
| 229 | + raise ValueError(f"Error reading file {file_path}: {e}") |
| 230 | + |
| 231 | + if isinstance(file_content, dict) and "eval_cases" in file_content: |
| 232 | + eval_cases = self._build_eval_set_from_eval_json(file_path).eval_cases |
| 233 | + elif ( |
| 234 | + isinstance(file_content, list) |
| 235 | + and len(file_content) > 0 |
| 236 | + and all( |
| 237 | + isinstance(span, dict) and "trace_id" in span |
| 238 | + for span in file_content |
| 239 | + ) |
| 240 | + ): |
| 241 | + eval_cases = self._build_eval_set_from_tracing_json( |
| 242 | + file_path |
| 243 | + ).eval_cases |
| 244 | + else: |
| 245 | + raise ValueError( |
| 246 | + f"Unsupported file format in {file_path}. Please provide a valid file." |
| 247 | + ) |
239 | 248 |
|
| 249 | + eval_case_data_list: list[EvalTestCase] = [] |
240 | 250 | for eval_case in eval_cases: |
241 | 251 | eval_case_data = EvalTestCase(invocations=[]) |
242 | 252 | if eval_case.session_input: |
@@ -384,8 +394,9 @@ def get_eval_set_information(self) -> list[list[dict[str, Any]]]: |
384 | 394 | @abstractmethod |
385 | 395 | async def evaluate( |
386 | 396 | self, |
387 | | - eval_set_file_path: str, |
388 | 397 | metrics: list[Any], |
| 398 | + eval_set: Optional[EvalSet], |
| 399 | + eval_set_file_path: Optional[str], |
389 | 400 | eval_id: str, |
390 | 401 | ): |
391 | 402 | """An abstract method for evaluation based on metrics。""" |
|
0 commit comments