|
1 |
| -from typing import Any |
| 1 | +from typing import Any, Optional |
2 | 2 |
|
| 3 | +import os |
3 | 4 | import yaml
|
4 | 5 | from loguru import logger
|
5 | 6 | from pydantic import BaseModel, ConfigDict
|
| 7 | +from enum import Enum |
| 8 | + |
| 9 | +from guidellm.utils import is_file_name |
| 10 | + |
| 11 | + |
| 12 | +__all__ = ["Serializable", "SerializableFileType"] |
| 13 | + |
| 14 | + |
| 15 | +class SerializableFileType(Enum): |
| 16 | + """ |
| 17 | + Enum class for file types supported by Serializable. |
| 18 | + """ |
| 19 | + |
| 20 | + YAML = "yaml" |
| 21 | + JSON = "json" |
6 | 22 |
|
7 | 23 |
|
8 | 24 | class Serializable(BaseModel):
|
@@ -73,3 +89,97 @@ def from_json(cls, data: str):
|
73 | 89 | obj = cls.model_validate_json(data)
|
74 | 90 |
|
75 | 91 | return obj
|
| 92 | + |
| 93 | + def save_file(self, path: str, type_: Optional[SerializableFileType] = None) -> str: |
| 94 | + """ |
| 95 | + Save the model to a file in either YAML or JSON format. |
| 96 | +
|
| 97 | + :param path: Path to the exact file or the containing directory. |
| 98 | + If it is a directory, the file name will be inferred from the class name. |
| 99 | + :param type_: Optional type to save ('yaml' or 'json'). |
| 100 | + If not provided and the path has an extension, |
| 101 | + it will be inferred to save in that format. |
| 102 | + If not provided and the path does not have an extension, |
| 103 | + it will save in YAML format. |
| 104 | + :return: The path to the saved file. |
| 105 | + """ |
| 106 | + logger.debug("Saving to file... {} with format: {}", path, type_) |
| 107 | + |
| 108 | + if not is_file_name(path): |
| 109 | + file_name = f"{self.__class__.__name__.lower()}" |
| 110 | + if type_: |
| 111 | + file_name += f".{type_.value.lower()}" |
| 112 | + else: |
| 113 | + file_name += ".yaml" |
| 114 | + type_ = SerializableFileType.YAML |
| 115 | + path = os.path.join(path, file_name) |
| 116 | + |
| 117 | + if not type_: |
| 118 | + extension = path.split(".")[-1].upper() |
| 119 | + |
| 120 | + if extension not in SerializableFileType.__members__: |
| 121 | + raise ValueError( |
| 122 | + f"Unsupported file extension: {extension}. " |
| 123 | + f"Expected one of {', '.join(SerializableFileType.__members__)}) " |
| 124 | + f"for {path}" |
| 125 | + ) |
| 126 | + |
| 127 | + type_ = SerializableFileType[extension] |
| 128 | + |
| 129 | + if type_.name not in SerializableFileType.__members__: |
| 130 | + raise ValueError( |
| 131 | + f"Unsupported file format: {type_} " |
| 132 | + f"(expected 'yaml' or 'json') for {path}" |
| 133 | + ) |
| 134 | + |
| 135 | + os.makedirs(os.path.dirname(path), exist_ok=True) |
| 136 | + |
| 137 | + with open(path, "w") as file: |
| 138 | + if type_ == SerializableFileType.YAML: |
| 139 | + file.write(self.to_yaml()) |
| 140 | + elif type_ == SerializableFileType.JSON: |
| 141 | + file.write(self.to_json()) |
| 142 | + else: |
| 143 | + raise ValueError(f"Unsupported file format: {type_}") |
| 144 | + |
| 145 | + logger.info("Successfully saved {} to {}", self.__class__.__name__, path) |
| 146 | + |
| 147 | + return path |
| 148 | + |
| 149 | + @classmethod |
| 150 | + def load_file(cls, path: str): |
| 151 | + """ |
| 152 | + Load a model from a file in either YAML or JSON format. |
| 153 | +
|
| 154 | + :param path: Path to the file. |
| 155 | + :return: An instance of the model. |
| 156 | + """ |
| 157 | + logger.debug("Loading from file... {}", path) |
| 158 | + |
| 159 | + if not os.path.exists(path): |
| 160 | + raise FileNotFoundError(f"File not found: {path}") |
| 161 | + elif not os.path.isfile(path): |
| 162 | + raise ValueError(f"Path is not a file: {path}") |
| 163 | + |
| 164 | + extension = path.split(".")[-1].upper() |
| 165 | + |
| 166 | + if extension not in SerializableFileType.__members__: |
| 167 | + raise ValueError( |
| 168 | + f"Unsupported file extension: {extension}. " |
| 169 | + f"Expected one of {', '.join(SerializableFileType.__members__)}) " |
| 170 | + f"for {path}" |
| 171 | + ) |
| 172 | + |
| 173 | + type_ = SerializableFileType[extension] |
| 174 | + |
| 175 | + with open(path, "r") as file: |
| 176 | + data = file.read() |
| 177 | + |
| 178 | + if type_ == SerializableFileType.YAML: |
| 179 | + obj = cls.from_yaml(data) |
| 180 | + elif type_ == SerializableFileType.JSON: |
| 181 | + obj = cls.from_json(data) |
| 182 | + else: |
| 183 | + raise ValueError(f"Unsupported file format: {type_}") |
| 184 | + |
| 185 | + return obj |
0 commit comments