11import json
2+ import random
23import typing
34from dataclasses import dataclass
45from pathlib import Path
6+ from string import ascii_lowercase , digits
57
68import yaml
79
1315_SAFE_DICT ["ToolCall" ] = ToolCall
1416
1517
18+ def _generate_uid ():
19+ return "" .join (random .choices (ascii_lowercase + digits , k = 8 ))
20+
21+
22+ @dataclass (frozen = True )
23+ class LambdaField :
24+ func : typing .Callable
25+
26+
1627@dataclass (frozen = True )
1728class DatasetField :
1829 name : str
@@ -28,15 +39,15 @@ def to_dict(self):
2839 }
2940
3041
31- @dataclass ( frozen = True )
42+ @dataclass
3243class DatasetManifest :
3344 name : str
3445 description : str
3546 format : str
3647 license : str
3748 fields : typing .Dict [str , DatasetField ]
3849
39- def to_yaml (self ):
50+ def to_dict (self ):
4051 return {
4152 "name" : self .name ,
4253 "description" : self .description ,
@@ -45,6 +56,24 @@ def to_yaml(self):
4556 "fields" : {field_name : field .to_dict () for field_name , field in self .fields .items ()},
4657 }
4758
59+ @classmethod
60+ def from_json (cls , data : typing .Dict ):
61+ return cls (
62+ name = data .get ("name" , "" ),
63+ description = data .get ("description" , "" ),
64+ format = data .get ("format" , "" ),
65+ license = data .get ("license" , "" ),
66+ fields = {
67+ field_name : DatasetField (
68+ name = field_name ,
69+ type = eval (field_info ["type" ], _SAFE_DICT ),
70+ description = field_info .get ("description" , "" ),
71+ is_ground_truth = field_info .get ("ground_truth" , False ),
72+ )
73+ for field_name , field_info in data ["fields" ].items ()
74+ },
75+ )
76+
4877
4978class Dataset :
5079 def __init__ (
@@ -68,14 +97,22 @@ def __init__(
6897 # load jsonl dataset
6998 with open (dataset_path , "r" ) as json_file :
7099 self ._data = [json .loads (x ) for x in json_file .readlines ()]
100+ for sample in self ._data :
101+ sample ["uid" ] = UID (sample ["uid" ]) if "uid" in sample else _generate_uid ()
71102 self ._manifest = self ._load_or_infer_manifest (manifest_path )
72103 self ._create_dynamic_properties ()
73104
74105 @classmethod
75- def from_data (cls , data : typing .List [typing .Dict [str , typing .Any ]]):
106+ def from_data (
107+ cls ,
108+ data : typing .List [typing .Dict [str , typing .Any ]],
109+ manifest : typing .Optional [typing .Dict ] = None ,
110+ ):
76111 dataset = cls .__new__ (cls )
77112 dataset ._data = data
78- dataset ._manifest = dataset ._infer_manifest ()
113+ for sample in dataset ._data :
114+ sample ["uid" ] = UID (sample ["uid" ]) if "uid" in sample else _generate_uid ()
115+ dataset ._manifest = DatasetManifest .from_json (manifest ) if manifest is not None else dataset ._infer_manifest ()
79116 dataset ._create_dynamic_properties ()
80117 return dataset
81118
@@ -89,7 +126,7 @@ def save(self, file_path: typing.Union[str, Path], save_manifest: bool = False):
89126 if save_manifest :
90127 manifest_path = file_path .parent / "manifest.yaml"
91128 with open (manifest_path , "w" ) as manifest_file :
92- manifest_file .write (yaml .dump (self ._manifest .to_yaml ()))
129+ manifest_file .write (yaml .dump (self ._manifest .to_dict ()))
93130
94131 def _load_or_infer_manifest (self , manifest_path : typing .Optional [Path ]) -> DatasetManifest :
95132 if manifest_path is None or not manifest_path .exists ():
@@ -147,6 +184,10 @@ def _create_dynamic_properties(self):
147184 def filed_types (self , name : str ) -> type :
148185 return getattr (self , name ).type
149186
187+ @property
188+ def manifest (self ):
189+ return self ._manifest
190+
150191 @property
151192 def data (self ):
152193 return self ._data
@@ -155,10 +196,18 @@ def data(self):
155196 def name (self ):
156197 return self ._manifest .name
157198
199+ @name .setter
200+ def name (self , value ):
201+ self ._manifest .name = value
202+
158203 @property
159204 def description (self ):
160205 return self ._manifest .description
161206
207+ @description .setter
208+ def description (self , value ):
209+ self ._manifest .description = value
210+
162211 @property
163212 def format (self ):
164213 return self ._manifest .format
@@ -167,13 +216,23 @@ def format(self):
167216 def license (self ):
168217 return self ._manifest .license
169218
219+ @license .setter
220+ def license (self , value ):
221+ self ._manifest .license = value
222+
170223 @property
171224 def fields (self ) -> typing .List [DatasetField ]:
172225 return list (self ._manifest .fields .values ())
173226
174227 def get_field (self , name : str ) -> DatasetField :
175228 return self ._manifest .fields [name ]
176229
230+ def get_by_uid (self , uid : str ) -> typing .Optional [typing .Dict ]:
231+ for sample in self ._data :
232+ if sample ["uid" ] == uid :
233+ return sample
234+ return None
235+
177236 def __getitem__ (self , key : str ):
178237 return [x [key ] for x in self ._data ]
179238
0 commit comments