@@ -11,7 +11,7 @@
 from litdata.utilities.train_test_split import deepcopy_dataset
 import copy
 
-from ..processors import get_processor
+from ..processors import get_processor, IgnoreProcessor
 from ..processors.base_processor import FeatureProcessor
 
 
@@ -191,8 +191,14 @@ def transform(self, sample: dict[str, bytes]) -> Dict[str, Any]:
         transformed: Dict[str, Any] = {}
         for key, value in pickle.loads(sample["sample"]).items():
             if key in self._input_processors:
+                # Skip ignored features
+                if isinstance(self._input_processors[key], IgnoreProcessor):
+                    continue
                 transformed[key] = self._input_processors[key].process(value)
             elif key in self._output_processors:
+                # Skip ignored features
+                if isinstance(self._output_processors[key], IgnoreProcessor):
+                    continue
                 transformed[key] = self._output_processors[key].process(value)
             else:
                 transformed[key] = value
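
Note: the skip logic above simply drops any feature whose processor is an `IgnoreProcessor` instance before it can reach the transformed sample. A minimal, self-contained sketch of the same pattern (the class bodies and sample keys here are illustrative stand-ins, not code from this repository):

    # Illustrative stand-ins; the real classes live in ..processors.
    class IgnoreProcessor:
        """Marker type: features mapped to it are dropped from samples."""

    class UpperProcessor:
        def process(self, value):
            return value.upper()

    processors = {"note": IgnoreProcessor(), "diagnosis": UpperProcessor()}
    sample = {"note": "free text", "diagnosis": "sepsis", "age": 71}

    transformed = {}
    for key, value in sample.items():
        if key in processors:
            if isinstance(processors[key], IgnoreProcessor):
                continue  # ignored feature: never reaches the output dict
            transformed[key] = processors[key].process(value)
        else:
            transformed[key] = value  # passthrough for keys without a processor

    print(transformed)  # {'diagnosis': 'SEPSIS', 'age': 71}
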
@@ -221,6 +227,30 @@ def save(self, path: str) -> None:
         with open(path, "wb") as f:
             pickle.dump(metadata, f)
 
+    @staticmethod
+    def load(path: str) -> "SampleBuilder":
+        """Load a SampleBuilder from a pickled metadata file.
+
+        Args:
+            path: Location of the pickled metadata file (commonly named `schema.pkl`).
+
+        Returns:
+            A SampleBuilder instance with loaded metadata.
+        """
+        with open(path, "rb") as f:
+            metadata = pickle.load(f)
+
+        builder = SampleBuilder(
+            input_schema=metadata["input_schema"],
+            output_schema=metadata["output_schema"],
+        )
+        builder._input_processors = metadata["input_processors"]
+        builder._output_processors = metadata["output_processors"]
+        builder._patient_to_index = metadata["patient_to_index"]
+        builder._record_to_index = metadata["record_to_index"]
+        builder._fitted = True
+        return builder
+
 
 class SampleDataset(litdata.StreamingDataset):
     """A streaming dataset that loads sample metadata and processors from disk.
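
A hedged round-trip sketch for the new `save`/`load` pair (the `schema.pkl` name follows the docstring; the schema contents are placeholders, not real dataset definitions):

    # Placeholder schemas; real ones come from the dataset definition.
    builder = SampleBuilder(
        input_schema={"diagnosis": "text"},
        output_schema={"mortality": "binary"},
    )
    # ... fit/attach processors as usual ...
    builder.save("schema.pkl")

    # Later, possibly in another process, restore the fitted builder:
    restored = SampleBuilder.load("schema.pkl")
    assert restored._fitted

One design note: `load` restores private state (`_input_processors`, `_fitted`, ...) directly rather than re-fitting, so the pickle must be opened by a code version that still defines the pickled processor classes.
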
@@ -276,10 +306,29 @@ def __init__(
         self.output_schema = metadata["output_schema"]
         self.input_processors = metadata["input_processors"]
         self.output_processors = metadata["output_processors"]
+        self._remove_ignored_processors()
 
         self.patient_to_index = metadata["patient_to_index"]
         self.record_to_index = metadata["record_to_index"]
 
+    def _remove_ignored_processors(self):
+        """Remove any processors that are IgnoreProcessor instances."""
+        for key in [
+            key
+            for key, proc in self.input_processors.items()
+            if isinstance(proc, IgnoreProcessor)
+        ]:
+            del self.input_processors[key]
+            del self.input_schema[key]
+
+        for key in [
+            key
+            for key, proc in self.output_processors.items()
+            if isinstance(proc, IgnoreProcessor)
+        ]:
+            del self.output_processors[key]
+            del self.output_schema[key]
+
     def __str__(self) -> str:
         """Returns a string representation of the dataset.
 
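One detail worth calling out in `_remove_ignored_processors`: the keys are collected into a list before deletion because deleting from a dict while iterating over it raises a `RuntimeError`. A short demonstration of the failure mode the list comprehension avoids:

    d = {"a": 1, "b": 2}
    try:
        for k in d:
            del d[k]  # mutating the dict mid-iteration
    except RuntimeError as exc:
        print(exc)  # dictionary changed size during iteration
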
@@ -356,12 +405,12 @@ def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset":
         new_dataset.reset()
 
         return new_dataset
-
+
     def close(self) -> None:
         """Cleans up any temporary directories used by the dataset."""
        if self.input_dir.path is not None and Path(self.input_dir.path).exists():
             shutil.rmtree(self.input_dir.path)
-
+
     # --------------------------------------------------------------
     # Context manager support
     # --------------------------------------------------------------
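
The `-`/`+` pairs in the hunk above differ only in trailing whitespace. More substantively, since `close()` deletes the dataset's temporary directory, the context-manager support it sits next to is the safer way to consume a streaming `SampleDataset`. A hedged usage sketch (the constructor argument is an assumption based on `litdata.StreamingDataset`, which this class subclasses):

    # Assumed construction; mirrors litdata.StreamingDataset's input_dir argument.
    with SampleDataset("path/to/processed_dataset") as ds:
        first = ds[0]
    # On exit, close() is expected to remove any temporary copy of the data.
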
@@ -426,6 +475,7 @@ def __init__(
         self.output_schema = builder.output_schema
         self.input_processors = builder.input_processors
         self.output_processors = builder.output_processors
+        self._remove_ignored_processors()
 
         self.patient_to_index = builder.patient_to_index
         self.record_to_index = builder.record_to_index
@@ -482,6 +532,7 @@ def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset:
     def close(self) -> None:
         pass  # No temporary directories to clean up for in-memory dataset
 
+
 def create_sample_dataset(
     samples: List[Dict[str, Any]],
     input_schema: Dict[str, Any],