Skip to content

Commit ee983a0

Browse files
committed
Bug fix and minor improvements
1 parent 583ad81 commit ee983a0

File tree

8 files changed

+850
-683
lines changed

8 files changed

+850
-683
lines changed

continuous_eval/eval/dataset.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import json
2+
import random
23
import typing
34
from dataclasses import dataclass
45
from pathlib import Path
6+
from string import ascii_lowercase, digits
57

68
import yaml
79

@@ -13,6 +15,15 @@
1315
_SAFE_DICT["ToolCall"] = ToolCall
1416

1517

18+
def _generate_uid():
    """Return a random 8-character id made of lowercase ASCII letters and digits."""
    alphabet = ascii_lowercase + digits
    picked = random.choices(alphabet, k=8)
    return "".join(picked)
20+
21+
22+
@dataclass(frozen=True)
class LambdaField:
    """Immutable wrapper around a callable used to compute a value from a sample."""

    # Callable invoked with a single positional argument by whoever consumes the field.
    func: typing.Callable
25+
26+
1627
@dataclass(frozen=True)
1728
class DatasetField:
1829
name: str
@@ -28,15 +39,15 @@ def to_dict(self):
2839
}
2940

3041

31-
@dataclass(frozen=True)
42+
@dataclass
3243
class DatasetManifest:
3344
name: str
3445
description: str
3546
format: str
3647
license: str
3748
fields: typing.Dict[str, DatasetField]
3849

39-
def to_yaml(self):
50+
def to_dict(self):
4051
return {
4152
"name": self.name,
4253
"description": self.description,
@@ -45,6 +56,24 @@ def to_yaml(self):
4556
"fields": {field_name: field.to_dict() for field_name, field in self.fields.items()},
4657
}
4758

59+
@classmethod
def from_json(cls, data: typing.Dict) -> "DatasetManifest":
    """Build a manifest from an already-parsed dictionary.

    Args:
        data: Mapping with optional ``name``, ``description``, ``format`` and
            ``license`` entries (each defaults to ``""``) and an optional
            ``fields`` mapping. Each ``fields`` value is a dict with a
            required ``type`` expression and optional ``description``
            (default ``""``) and ``ground_truth`` (default ``False``) keys.

    Returns:
        A new manifest instance populated from ``data``.

    Raises:
        KeyError: If a field entry has no ``type`` key.
    """
    # A manifest without a "fields" section yields an empty field mapping,
    # in line with the defaults applied to the other top-level keys.
    raw_fields = data.get("fields") or {}
    fields = {}
    for field_name, field_info in raw_fields.items():
        # SECURITY NOTE(review): the type expression is run through eval()
        # with _SAFE_DICT as its globals. This is only acceptable for
        # trusted manifests -- confirm no caller passes untrusted content.
        field_type = eval(field_info["type"], _SAFE_DICT)
        fields[field_name] = DatasetField(
            name=field_name,
            type=field_type,
            description=field_info.get("description", ""),
            is_ground_truth=field_info.get("ground_truth", False),
        )
    return cls(
        name=data.get("name", ""),
        description=data.get("description", ""),
        format=data.get("format", ""),
        license=data.get("license", ""),
        fields=fields,
    )
76+
4877

4978
class Dataset:
5079
def __init__(
@@ -68,14 +97,22 @@ def __init__(
6897
# load jsonl dataset
6998
with open(dataset_path, "r") as json_file:
7099
self._data = [json.loads(x) for x in json_file.readlines()]
100+
for sample in self._data:
101+
sample["uid"] = UID(sample["uid"]) if "uid" in sample else _generate_uid()
71102
self._manifest = self._load_or_infer_manifest(manifest_path)
72103
self._create_dynamic_properties()
73104

74105
@classmethod
def from_data(
    cls,
    data: typing.List[typing.Dict[str, typing.Any]],
    manifest: typing.Optional[typing.Dict] = None,
):
    """Create a dataset from in-memory samples without running ``__init__``.

    Every sample ends up with a ``uid`` entry: an existing value is passed
    through ``UID``, a missing one is generated. The sample dicts in ``data``
    are updated in place.

    Args:
        data: List of sample dictionaries.
        manifest: Optional manifest dictionary parsed with
            ``DatasetManifest.from_json``; when ``None`` the manifest is
            inferred from the data.

    Returns:
        The newly built dataset instance.
    """
    instance = cls.__new__(cls)
    instance._data = data
    for record in instance._data:
        if "uid" in record:
            record["uid"] = UID(record["uid"])
        else:
            record["uid"] = _generate_uid()
    if manifest is None:
        instance._manifest = instance._infer_manifest()
    else:
        instance._manifest = DatasetManifest.from_json(manifest)
    instance._create_dynamic_properties()
    return instance
81118

@@ -89,7 +126,7 @@ def save(self, file_path: typing.Union[str, Path], save_manifest: bool = False):
89126
if save_manifest:
90127
manifest_path = file_path.parent / "manifest.yaml"
91128
with open(manifest_path, "w") as manifest_file:
92-
manifest_file.write(yaml.dump(self._manifest.to_yaml()))
129+
manifest_file.write(yaml.dump(self._manifest.to_dict()))
93130

94131
def _load_or_infer_manifest(self, manifest_path: typing.Optional[Path]) -> DatasetManifest:
95132
if manifest_path is None or not manifest_path.exists():
@@ -147,6 +184,10 @@ def _create_dynamic_properties(self):
147184
def filed_types(self, name: str) -> type:
148185
return getattr(self, name).type
149186

187+
@property
188+
def manifest(self):
189+
return self._manifest
190+
150191
@property
151192
def data(self):
152193
return self._data
@@ -155,10 +196,18 @@ def data(self):
155196
def name(self):
156197
return self._manifest.name
157198

199+
@name.setter
200+
def name(self, value):
201+
self._manifest.name = value
202+
158203
@property
159204
def description(self):
160205
return self._manifest.description
161206

207+
@description.setter
208+
def description(self, value):
209+
self._manifest.description = value
210+
162211
@property
163212
def format(self):
164213
return self._manifest.format
@@ -167,13 +216,23 @@ def format(self):
167216
def license(self):
168217
return self._manifest.license
169218

219+
@license.setter
220+
def license(self, value):
221+
self._manifest.license = value
222+
170223
@property
171224
def fields(self) -> typing.List[DatasetField]:
172225
return list(self._manifest.fields.values())
173226

174227
def get_field(self, name: str) -> DatasetField:
175228
return self._manifest.fields[name]
176229

230+
def get_by_uid(self, uid: str) -> typing.Optional[typing.Dict]:
    """Return the first sample whose ``"uid"`` entry equals ``uid``, or ``None``."""
    matching = (record for record in self._data if record["uid"] == uid)
    return next(matching, None)
235+
177236
def __getitem__(self, key: str):
178237
return [x[key] for x in self._data]
179238

continuous_eval/eval/modules.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,25 @@
33

44
from continuous_eval.eval.dataset import DatasetField
55
from continuous_eval.eval.tests import Test
6+
from continuous_eval.eval.utils import type_hint_to_str
67
from continuous_eval.metrics import Metric
78

89

10+
def _serialize_input_type(obj):
    """Convert a module input specification into plain serializable data.

    Dataset fields and modules become ``{"__class__": ..., "name": ...}``,
    types go through ``type_hint_to_str``, lists and tuples are converted
    element by element into a list, and ``None`` becomes the string ``"None"``.

    Raises:
        TypeError: If ``obj`` is none of the supported kinds.
    """
    if isinstance(obj, (DatasetField, Module)):
        return {"__class__": obj.__class__.__name__, "name": obj.name}
    if isinstance(obj, type):
        return type_hint_to_str(obj)
    if isinstance(obj, (list, tuple)):
        return [_serialize_input_type(item) for item in obj]
    if obj is None:
        return "None"
    raise TypeError(f"Object of type {type(obj).__name__} is not serializable")
23+
24+
925
@dataclass(frozen=True, eq=True)
1026
class Tool:
1127
name: str
@@ -33,6 +49,16 @@ def __post_init__(self):
3349
eval_names = {metric.name for metric in self.eval}
3450
assert len(eval_names) == len(self.eval), f"Each metric name must be unique"
3551

52+
def asdict(self):
    """Return this module as a dictionary of serializable values.

    The result has the keys ``name``, ``input``, ``output``, ``description``,
    ``eval`` and ``tests``. ``eval`` and ``tests`` are lists of the
    ``asdict()`` output of each metric / test, or ``None`` when the
    corresponding attribute is empty or unset.
    """
    serialized = dict()
    serialized["name"] = self.name
    serialized["input"] = _serialize_input_type(self.input)
    serialized["output"] = type_hint_to_str(self.output)
    serialized["description"] = self.description
    if self.eval:
        serialized["eval"] = [metric.asdict() for metric in self.eval]
    else:
        serialized["eval"] = None
    if self.tests:
        serialized["tests"] = [test.asdict() for test in self.tests]
    else:
        serialized["tests"] = None
    return serialized
61+
3662

3763
@dataclass(frozen=True, eq=True)
3864
class AgentModule(Module):

continuous_eval/eval/pipeline.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Any, Callable, List, Optional, Set, Tuple
2+
from typing import Any, Callable, List, Optional, Set, Tuple, Union
33

44
from continuous_eval.eval.dataset import Dataset, DatasetField
55
from continuous_eval.eval.modules import Module, SingleModule
@@ -11,7 +11,7 @@
1111
@dataclass
1212
class ModuleOutput:
1313
selector: Callable = field(default=lambda x: x)
14-
module: Optional[Module] = None
14+
module: Optional[Union[Module, str]] = None
1515

1616
def __call__(self, *args: Any) -> Any:
1717
return self.selector(*args)
@@ -34,7 +34,7 @@ class Graph:
3434

3535

3636
class Pipeline:
37-
def __init__(self, modules: List[Module], dataset: Dataset) -> None:
37+
def __init__(self, modules: List[Module], dataset: Optional[Dataset] = None) -> None:
3838
self._modules = modules
3939
self._dataset = dataset
4040
self._graph = self._build_graph()
@@ -47,6 +47,10 @@ def modules(self):
4747
def dataset(self):
4848
return self._dataset
4949

50+
@dataset.setter
51+
def dataset(self, dataset: Dataset):
52+
self._dataset = dataset
53+
5054
def module_by_name(self, name: str) -> Module:
5155
for module in self._modules:
5256
if module.name == name:
@@ -71,6 +75,8 @@ def _validate_modules(self):
7175
names.add(module.name)
7276

7377
def _build_graph(self):
78+
if self._dataset is None:
79+
return None
7480
nodes = {m.name for m in self._modules}
7581
edges = set()
7682
dataset_edges = set()
@@ -96,6 +102,8 @@ def _build_graph(self):
96102
return Graph(nodes, edges, dataset_edges)
97103

98104
def graph_repr(self, with_type_hints: bool = False):
105+
if self._graph is None:
106+
return None
99107
repr_str = "graph TD;\n"
100108
dataset_node_label = "Dataset"
101109
repr_str += f" {dataset_node_label}(({dataset_node_label}));\n"
@@ -112,9 +120,14 @@ def graph_repr(self, with_type_hints: bool = False):
112120
repr_str += f' {dataset_node_label} -. "{dataset_field_name}" .-> {end_node};\n'
113121
return repr_str
114122

123+
def asdict(self):
    """Return ``{"modules": [...]}`` holding the ``asdict()`` output of every module."""
    serialized_modules = []
    for module in self._modules:
        serialized_modules.append(module.asdict())
    return {"modules": serialized_modules}
127+
115128

116129
def SingleModulePipeline(
117-
dataset: Dataset,
130+
dataset: Optional[Dataset] = None,
118131
eval: Optional[List[Metric]] = None,
119132
tests: Optional[List[Test]] = None,
120133
name: str = "eval",

continuous_eval/eval/runner.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from typing import Optional, Union
33

4-
from continuous_eval.eval.dataset import Dataset, DatasetField
4+
from continuous_eval.eval.dataset import Dataset, DatasetField, LambdaField
55
from continuous_eval.eval.logger import PipelineLogger
66
from continuous_eval.eval.modules import Module
77
from continuous_eval.eval.pipeline import CalledTools, ModuleOutput, Pipeline
@@ -30,22 +30,49 @@ def dataset(self) -> Dataset:
3030
return self._pipeline.dataset
3131

3232
# Evaluate
33-
34-
def _prepare(self, eval_results: PipelineResults, module: Module, metric: Metric):
33+
@staticmethod
def prepare(dataset: Dataset, eval_results: PipelineResults, module: Module, metric: Metric):
    """Assemble the keyword arguments passed to ``metric.batch``.

    When ``metric.overloaded_params`` is ``None`` the results are pivoted:
    for every entry of ``eval_results.results`` the module's own sub-dict is
    used if present (otherwise the entry itself), and each key collects its
    values into a list.

    Otherwise each ``(key, val)`` of ``overloaded_params`` (the ``"uid"`` key
    is skipped) is resolved according to the type of ``val``:

    * ``DatasetField``: one value per dataset sample, read from the sample's
      ``module.name`` sub-dict when present, else from the sample itself.
    * ``LambdaField``: ``val.func`` applied, per result, to the module output
      when present, else to the dataset sample with the same ``uid``.
    * ``ModuleOutput``: ``val`` applied to the output of the referenced
      module (``val.module`` may be ``None``, a module name or a module).
    * ``CalledTools``: ``val`` applied to the ``TOOL_PREFIX``-prefixed entry
      of the referenced module.

    Args:
        dataset: Dataset whose ``data`` samples are read.
        eval_results: Object whose ``results`` list holds per-sample outputs.
        module: Module the metric belongs to; its ``name`` is the default key.
        metric: Metric whose ``overloaded_params`` drive the mapping.

    Returns:
        Dict mapping parameter names to lists of per-sample values.

    Raises:
        ValueError: If an overloaded parameter has an unsupported type.
    """
    kwargs = dict()

    if metric.overloaded_params is None:
        for item in eval_results.results:
            itr = item[module.name] if module.name in item else item
            for key, value in itr.items():
                kwargs.setdefault(key, []).append(value)
        return kwargs

    def _target_name(ref):
        # ``ref`` is None (use the current module), a module name, or an
        # object exposing ``name``. The previous code tested the wrapper
        # object instead of ``ref``, so a module instance was never resolved
        # to its name.
        if ref is None:
            return module.name
        return ref if isinstance(ref, str) else ref.name

    for key, val in metric.overloaded_params.items():
        if key == "uid":
            continue
        if isinstance(val, DatasetField):
            kwargs[key] = [x[module.name][val.name] if module.name in x else x[val.name] for x in dataset.data]  # type: ignore
        elif isinstance(val, LambdaField):
            values = []
            samples_by_uid = None  # built on first use, one pass over the dataset
            for rx in eval_results.results:
                uid = rx["uid"]
                if module.name in rx:
                    values.append(val.func(rx[module.name]))
                    continue
                if samples_by_uid is None:
                    samples_by_uid = {}
                    for sample in dataset.data:
                        # First sample wins for duplicated uids, matching a
                        # front-to-back scan.
                        samples_by_uid.setdefault(sample["uid"], sample)
                # A result with no matching dataset sample contributes nothing.
                if uid in samples_by_uid:
                    values.append(val.func(samples_by_uid[uid]))
            kwargs[key] = values
        elif isinstance(val, ModuleOutput):
            module_name = _target_name(val.module)
            kwargs[key] = [val(x[module_name]) for x in eval_results.results]
        elif isinstance(val, CalledTools):
            val_key = f"{TOOL_PREFIX}{_target_name(val.module)}"
            kwargs[key] = [val(x[val_key]) for x in eval_results.results]
        else:
            raise ValueError(f"Invalid promised parameter {key}={val}")
    return kwargs
5077

5178
@telemetry_event("eval_manager")
@@ -67,7 +94,8 @@ def evaluate(
6794
metrics_results = MetricsResults(self.pipeline)
6895
metrics_results.samples = {
6996
module.name: {
70-
metric.name: metric.batch(**self._prepare(eval_results, module, metric)) for metric in module.eval
97+
metric.name: metric.batch(**self.prepare(self.dataset, eval_results, module, metric))
98+
for metric in module.eval
7199
}
72100
for module in self._pipeline.modules
73101
if module.eval is not None

continuous_eval/eval/tests.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ def run(self, metrics_per_sample) -> bool:
2020
"""
2121
raise NotImplementedError
2222

23+
def asdict(self):
    """Return a minimal description of this test: its class name and its ``name``."""
    description = dict()
    description["__class__"] = self.__class__.__name__
    description["name"] = self.name
    return description
28+
2329

2430
# Some common tests
2531
class GreaterOrEqualThan(Test):

continuous_eval/metrics/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ def aggregate(self, results: List[Any]) -> Any:
5858
def name(self):
5959
return self.__class__.__name__
6060

61+
def asdict(self):
    """Return a minimal description of this metric: its class name and its ``name``."""
    description = dict()
    description["__class__"] = self.__class__.__name__
    description["name"] = self.name
    return description
66+
6167

6268
class LLMBasedMetric(Metric):
6369
"""

0 commit comments

Comments
 (0)