|
1 | 1 | import hashlib |
2 | 2 | import logging |
3 | 3 | import re |
4 | | -from pathlib import Path |
5 | 4 | import coloredlogs |
6 | 5 | import yaml |
7 | 6 | from os.path import join, isfile |
8 | | -from yamlcore import CoreLoader, CoreDumper |
| 7 | +from yamlcore import CoreDumper |
9 | 8 |
|
10 | 9 |
|
11 | 10 | # Hex integers |
@@ -52,44 +51,87 @@ def hash_yaml(section_to_hash): |
52 | 51 | return hash_digest |
53 | 52 |
|
54 | 53 |
|
55 | | -def patch_config(base_config, patch): |
56 | | - # Merge configs. |
57 | | - def _recursive_update(base, new): |
58 | | - for k, v in new.items(): |
59 | | - if isinstance(v, dict): |
60 | | - base[k] = _recursive_update(base.get(k, {}), v) |
61 | | - elif isinstance(v, list): |
62 | | - # Append |
63 | | - base[k] = base.get(k, []) + v |
64 | | - else: |
65 | | - base[k] = v |
66 | | - return base |
67 | | - |
68 | | - if issubclass(type(patch), Path): |
69 | | - with open(patch, "r") as f: |
70 | | - patch = yaml.load(f, Loader=CoreLoader) |
| 54 | +def patch_config(logger, base_config, patch): |
71 | 55 | if not patch: |
72 | 56 | # Empty patch, possibly an empty file or one with all comments |
73 | 57 | return base_config |
74 | | - for key, value in patch.items(): |
75 | | - # Check if the key already exists in the base_config |
76 | | - if key in base_config: |
77 | | - # If the value is a dictionary, update subfields |
78 | | - if isinstance(value, dict): |
79 | | - # Recursive update to handle nested dictionaries |
80 | | - base_config[key] = _recursive_update(base_config.get(key, {}), value) |
81 | | - elif isinstance(value, list): |
82 | | - # Merge lists |
83 | | - seen = set() |
84 | | - combined = base_config[key] + value |
85 | | - base_config[key] = [x for x in combined if not (x in seen or seen.add(x))] |
86 | | - else: |
87 | | - # Replace the base value with the incoming value |
88 | | - base_config[key] = value |
89 | | - else: |
90 | | - # New key, add all data directly |
91 | | - base_config[key] = value |
92 | | - return base_config |
| 58 | + |
| 59 | + # Merge configs. |
| 60 | + def _recursive_update(base, new, config_option): |
| 61 | + if base is None: |
| 62 | + return new |
| 63 | + if new is None: |
| 64 | + return base |
| 65 | + |
| 66 | + assert type(base) is type(new) |
| 67 | + |
| 68 | + if hasattr(base, "merge"): |
| 69 | + return base.merge(new) |
| 70 | + |
| 71 | + if hasattr(base, "model_fields_set"): |
| 72 | + result = dict() |
| 73 | + for base_key in base.model_fields_set: |
| 74 | + result[base_key] = getattr(base, base_key) |
| 75 | + if base.model_extra is not None: |
| 76 | + for base_key, base_value in base.model_extra.items(): |
| 77 | + result[base_key] = base_value |
| 78 | + for new_key in new.model_fields_set: |
| 79 | + new_value = getattr(new, new_key) |
| 80 | + if new_key in result: |
| 81 | + result[new_key] = _recursive_update( |
| 82 | + result[new_key], |
| 83 | + new_value, |
| 84 | + f"{config_option}.{new_key}" if config_option else new_key, |
| 85 | + ) |
| 86 | + else: |
| 87 | + result[new_key] = new_value |
| 88 | + if new.model_extra is not None: |
| 89 | + for new_key, new_value in new.model_extra.items(): |
| 90 | + if new_key in result: |
| 91 | + result[new_key] = _recursive_update( |
| 92 | + result[new_key], |
| 93 | + new_value, |
| 94 | + f"{config_option}.{new_key}" if config_option else new_key, |
| 95 | + ) |
| 96 | + else: |
| 97 | + result[new_key] = new_value |
| 98 | + return type(base)(**result) |
| 99 | + |
| 100 | + if isinstance(base, list): |
| 101 | + return base + new |
| 102 | + |
| 103 | + if isinstance(base, dict): |
| 104 | + result = dict() |
| 105 | + for key, base_value in base.items(): |
| 106 | + if key in new: |
| 107 | + new_value = new[key] |
| 108 | + result[key] = _recursive_update( |
| 109 | + base_value, |
| 110 | + new_value, |
| 111 | + f"{config_option}.{key}" if config_option else key, |
| 112 | + ) |
| 113 | + else: |
| 114 | + result[key] = base_value |
| 115 | + for new_key, new_value in new.items(): |
| 116 | + if new_key not in base: |
| 117 | + result[new_key] = new_value |
| 118 | + return result |
| 119 | + |
| 120 | + if base == new: |
| 121 | + return base |
| 122 | + |
| 123 | + base_str = yaml.dump(base).strip().removesuffix("...").strip() |
| 124 | + new_str = yaml.dump(new).strip().removesuffix("...").strip() |
| 125 | + change_str = ( |
| 126 | + f"\n```\n{base_str}\n```↓\n```\n{new_str}\n```" |
| 127 | + if "\n" in base_str + new_str |
| 128 | + else f"`{base_str}` → `{new_str}`" |
| 129 | + ) |
| 130 | + logger.warning(f"patch conflict: {config_option}: {change_str}") |
| 131 | + |
| 132 | + return new |
| 133 | + |
| 134 | + return _recursive_update(base_config, patch, None) |
93 | 135 |
|
94 | 136 |
|
95 | 137 | class PathHighlightingFormatter(coloredlogs.ColoredFormatter): |
|
0 commit comments