|
| 1 | +import warnings |
1 | 2 | from dataclasses import dataclass |
2 | 3 | from enum import Enum |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +from radiosim.ppdisks.config import Parser, Variables |
3 | 7 |
|
4 | 8 | __all__ = [ |
5 | 9 | "FargoOptionConfig", |
|
8 | 12 |
|
9 | 13 |
|
10 | 14 | class OptionType(Enum): |
11 | | - VARIABLE = 1 |
12 | | - LIST = 2 |
13 | | - OPTION = 3 |
| 15 | + VARIABLE = " = " |
| 16 | + LIST = " := " |
| 17 | + OPTION = " += " |
| 18 | + |
| 19 | + def __init__(self, seperator: str): |
| 20 | + self.seperator: str = seperator |
| 21 | + |
| 22 | + def split_line(self, line: str) -> tuple[str]: |
| 23 | + if line.endswith("\n"): |
| 24 | + line = line.replace("\n", "") |
| 25 | + |
| 26 | + key, value = line.split(self.seperator) |
| 27 | + |
| 28 | + @classmethod |
| 29 | + def from_value(cls, value: str) -> "OptionType": |
| 30 | + for option_type in cls.__members__.values(): |
| 31 | + if option_type.seperator in value: |
| 32 | + return option_type |
| 33 | + |
| 34 | + raise TypeError("The given value is no valid OptionType!") |
14 | 35 |
|
15 | 36 |
|
16 | 37 | @dataclass |
17 | 38 | class FargoOptionEntry: |
18 | 39 | value: object | None |
19 | 40 | option_type: OptionType |
20 | 41 | enabled: bool = True |
21 | | - cuda_only: bool = False |
22 | 42 |
|
23 | 43 | def enable(self) -> None: |
24 | 44 | self.enabled = True |
25 | 45 |
|
26 | 46 | def disable(self) -> None: |
27 | 47 | self.enabled = False |
28 | 48 |
|
| 49 | + def get_line(self, key: str) -> str: |
| 50 | + match self.option_type: |
| 51 | + case OptionType.VARIABLE: |
| 52 | + return f"{key} = {self.value}" |
| 53 | + case OptionType.LIST: |
| 54 | + if not isinstance(self.value, list): |
| 55 | + raise ValueError( |
| 56 | + "The OptionType LIST has to contain a value with type list!" |
| 57 | + ) |
| 58 | + return f"{key} := {' '.join(self.value)}" |
| 59 | + case OptionType.OPTION: |
| 60 | + if self.value is None: |
| 61 | + return f"FARGO_OPT += -D{key}" |
| 62 | + else: |
| 63 | + return f"FARGO_OPT += -D{key}={self.value}" |
| 64 | + |
29 | 65 | @classmethod |
30 | | - def option( |
31 | | - enabled: bool, value: object | None = None, cuda_only: bool = False |
32 | | - ) -> "FargoOptionEntry": |
| 66 | + def option(enabled: bool, value: object | None = None) -> "FargoOptionEntry": |
33 | 67 | return FargoOptionEntry( |
34 | 68 | value=value, |
35 | 69 | option_type=OptionType.OPTION, |
36 | 70 | enabled=enabled, |
37 | | - cuda_only=cuda_only, |
38 | 71 | ) |
39 | 72 |
|
40 | 73 |
|
41 | 74 | class FargoOptionConfig: |
42 | | - def __init__(self): |
43 | | - self._parameters = { |
| 75 | + def __init__(self, setup: str, autosave: bool = False): |
| 76 | + self._path: Path = Variables.get("FARGO_ROOT") / f"setups/{setup}/{setup}.opt" |
| 77 | + self._autosave: bool = autosave |
| 78 | + |
| 79 | + if not self._path.exists(): |
| 80 | + raise NameError(f"The given setup '{setup}' does not exist!") |
| 81 | + |
| 82 | + # initialize default parameters |
| 83 | + self._parameters: dict = { |
44 | 84 | "fluids": { |
45 | 85 | # This parameter automatically implies the list variable FLUIDS |
46 | 86 | # and the option NFLUIDS |
@@ -102,14 +142,143 @@ def __init__(self): |
102 | 142 | "LONGSUMMARY": FargoOptionEntry.option(enabled=True), |
103 | 143 | }, |
104 | 144 | "cuda_blocks": { |
105 | | - "BLOCK_X": FargoOptionEntry.option( |
106 | | - value=16, enabled=True, cuda_only=True |
107 | | - ), |
108 | | - "BLOCK_Y": FargoOptionEntry.option( |
109 | | - value=16, enabled=True, cuda_only=True |
110 | | - ), |
111 | | - "BLOCK_Z": FargoOptionEntry.option( |
112 | | - value=1, enabled=True, cuda_only=True |
113 | | - ), |
| 145 | + "BLOCK_X": FargoOptionEntry.option(value=16, enabled=True), |
| 146 | + "BLOCK_Y": FargoOptionEntry.option(value=16, enabled=True), |
| 147 | + "BLOCK_Z": FargoOptionEntry.option(value=1, enabled=True), |
114 | 148 | }, |
115 | 149 | } |
| 150 | + |
| 151 | + def disable_all(self): |
| 152 | + for _category, category_dict in self._parameters.items(): |
| 153 | + for _key, value in category_dict.items(): |
| 154 | + value.disable() |
| 155 | + |
| 156 | + def save(self): |
| 157 | + with open(self._path) as file: |
| 158 | + old_content = file.read() |
| 159 | + |
| 160 | + with open(self._path, "w") as file: |
| 161 | + try: |
| 162 | + lines = [] |
| 163 | + for category, category_dict in self._parameters.items(): |
| 164 | + if len(category_dict > 0): |
| 165 | + lines.append("\n") |
| 166 | + lines.append(f"# [{category}]\n") |
| 167 | + lines.append("\n") |
| 168 | + |
| 169 | + if category == "cuda_blocks": |
| 170 | + lines.append("ifeq (${GPU}, 1)\n") |
| 171 | + |
| 172 | + for key, value in category_dict.items(): |
| 173 | + if not value.enabled: |
| 174 | + continue |
| 175 | + |
| 176 | + if key == "NFLUIDS": |
| 177 | + lines.append( |
| 178 | + FargoOptionEntry( |
| 179 | + value=list(range(value.value)), |
| 180 | + option_type=OptionType.LIST, |
| 181 | + enabled=True, |
| 182 | + ).get_line(key="FLUIDS") |
| 183 | + ) |
| 184 | + lines.append(value.get_line(key=key)) |
| 185 | + lines.append( |
| 186 | + FargoOptionEntry.option( |
| 187 | + value="${NFLUIDS}", enabled=True |
| 188 | + ).get_line(key="NFLUIDS") |
| 189 | + ) |
| 190 | + else: |
| 191 | + lines.append(value.get_line(key=key)) |
| 192 | + |
| 193 | + if category == "cuda_blocks": |
| 194 | + lines.append("endif") |
| 195 | + |
| 196 | + file.writelines(lines) |
| 197 | + |
| 198 | + except Exception as e: |
| 199 | + warnings.warn( |
| 200 | + "An error occured while saving. Rolling back configuration files.", |
| 201 | + stacklevel=1, |
| 202 | + ) |
| 203 | + file.write(old_content) |
| 204 | + raise e |
| 205 | + |
| 206 | + def load(self): |
| 207 | + self.disable_all() |
| 208 | + with open(self._path) as file: |
| 209 | + lines = file.readlines() |
| 210 | + |
| 211 | + current_category = None |
| 212 | + for line in lines: |
| 213 | + if line.strip() == "": |
| 214 | + continue |
| 215 | + |
| 216 | + if line.startswith("# ") and "[" in line and "]" in line: |
| 217 | + current_category = ( |
| 218 | + line.removeprefix("# ").split("[")[1].split("]")[0] |
| 219 | + ) |
| 220 | + continue |
| 221 | + |
| 222 | + if current_category is None: |
| 223 | + raise ValueError( |
| 224 | + "The sections of the .opt file are not valid! " |
| 225 | + "Every variable must be inside a catgeory!" |
| 226 | + ) |
| 227 | + |
| 228 | + option_type = OptionType.from_value(line) |
| 229 | + key, value = option_type.split_line(line=line) |
| 230 | + entry = self[key] |
| 231 | + if entry.option_type != option_type: |
| 232 | + raise TypeError( |
| 233 | + f"The variable '{key}' has to have the type " |
| 234 | + f"{entry.option_type}! (In config: {option_type})" |
| 235 | + ) |
| 236 | + |
| 237 | + entry.value = Parser().parse(value) |
| 238 | + entry.enable() |
| 239 | + |
| 240 | + def __getitem__(self, key: str) -> FargoOptionEntry | dict: |
| 241 | + key_components = key.split(".") |
| 242 | + |
| 243 | + match len(key_components): |
| 244 | + case 1: |
| 245 | + return self._parameters[key_components[0]] |
| 246 | + case 2: |
| 247 | + return self._parameters[key_components[0]][key_components[1]] |
| 248 | + case _: |
| 249 | + if len(key_components) > 2: |
| 250 | + raise KeyError( |
| 251 | + "The maximum depth of a config key is 2 (catgeory -> entry)!" |
| 252 | + ) |
| 253 | + |
| 254 | + def __setitem__(self, key: str, value: object) -> None: |
| 255 | + key_components = key.split(".") |
| 256 | + |
| 257 | + match len(key_components): |
| 258 | + case 1: |
| 259 | + if isinstance(value, dict): |
| 260 | + self._parameters[key_components[0]] = value |
| 261 | + return None |
| 262 | + else: |
| 263 | + raise TypeError("Values at root level must either be a dict!") |
| 264 | + case 2: |
| 265 | + if isinstance(value, FargoOptionEntry): |
| 266 | + self._parameters[key_components[0]][key_components[1]] = value |
| 267 | + elif isinstance( |
| 268 | + self._parameters[key_components[0]][key_components[1]], |
| 269 | + FargoOptionEntry, |
| 270 | + ): |
| 271 | + self._parameters[key_components[0]][key_components[1]].value = value |
| 272 | + else: |
| 273 | + raise TypeError( |
| 274 | + "This key does not point to a valid entry! Enter an instance " |
| 275 | + "of a 'FargoOptionEntry'" |
| 276 | + ) |
| 277 | + case _: |
| 278 | + if len(key_components) > 2: |
| 279 | + raise KeyError( |
| 280 | + "The maximum depth of a config key is 2 (catgeory -> entry)!" |
| 281 | + ) |
| 282 | + |
| 283 | + if self._autosave: |
| 284 | + self.save() |
0 commit comments