33from copy import deepcopy
44from pathlib import Path
55from pprint import pformat
6- from typing import Any , Dict , List , Literal , Optional , Type , Union
6+ from typing import Any , Dict , List , Optional , Type , Union
77
88import yaml as _yaml
99from camel_converter import to_snake as _to_snake
1010from munch import Munch as _Munch
1111from munch import munchify as _munchify
12- from pydantic import BaseModel , ConfigDict , create_model
12+ from pydantic import BaseModel , ConfigDict , Field , create_model , model_validator
1313
1414
1515logging .basicConfig (format = '%(asctime)-15s \t %(levelname)-8s \t %(name)-8s \t %(message)s' )
@@ -50,7 +50,7 @@ def init_config(
5050 raise_error_non_identifiers: raise error if config section name is not a valid identifier
5151 validate_data_types: raise error if data types in config are not the same as default (makes sense only if merge is enabled)
5252 allow_extra_sections: raise error if there are extra sections in config (may break if section name formatting is enabled)
53- warn_extra_sections: warn about extra keys and values on the first level
53+ warn_extra_sections: warn about extra keys and values
5454 """
5555
5656 def _merge_configs (
@@ -70,13 +70,17 @@ def _merge_configs(
7070 f_section = _sanitize_section (section )
7171 sections .append (f_section )
7272 if f_section not in _raw_data :
73+ if isinstance (entry , Dict ):
74+ entry = _sanitize_keys (entry )
7375 _raw_data [f_section ] = entry
7476 logger .debug (f'section `{ "." .join (sections )} ` with value `{ entry } ` taken from { default_config } ' )
7577 else :
7678 logger .debug (f'section `{ "." .join (sections )} ` already exists in { config } , skipping' )
7779 elif isinstance (entry , Dict ):
7880 sections .append (section )
7981 _merge_configs (_raw_data [section ], entry , sections )
82+ f_section = _sanitize_section (section )
83+ _raw_data [f_section ] = _raw_data .pop (section , None )
8084 # TODO: add support for merging lists
8185 else :
8286 f_section = _sanitize_section (section )
@@ -109,25 +113,73 @@ def _pop_ignored_keys(data: ConfigType) -> ConfigType:
109113 _pop_ignored_keys (entry )
110114 return data
111115
112- def _model_from_dict (name : str , data : Dict [str , Any ], extra : bool ) -> Type [BaseModel ]:
def _sanitize_keys(data: ConfigType) -> ConfigType:
    """Rename every key of ``data`` (at every nesting level) via ``_sanitize_section``.

    The mapping is modified in place and also returned for convenience.

    Previously only keys holding non-dict values were renamed; keys that held
    nested dicts kept their raw spelling. ``_model_from_dict`` sanitizes
    section names at every depth, so nested section names have to be
    sanitized here as well to stay consistent with the generated model.

    Args:
        data: mapping whose keys should be sanitized.

    Returns:
        The same ``data`` object, with sanitized keys.
    """
    # Walk a snapshot of the items: keys are popped and re-inserted while iterating.
    for key, entry in list(data.items()):
        if isinstance(entry, Dict):
            _sanitize_keys(entry)
        # NOTE(review): if two raw keys sanitize to the same name, the later
        # one silently wins - confirm whether that should be an error.
        data[_sanitize_section(key)] = data.pop(key)
    return data
123+
124+ def _pop_nested (d : Dict [str , Any ], dotted_key : str , default : Any = None ) -> Any :
125+ keys = dotted_key .split ('.' )
126+ current = d
127+
128+ for k in keys [:- 1 ]:
129+ if not isinstance (current , dict ) or k not in current :
130+ return default
131+ current = current [k ]
132+
133+ return current .pop (keys [- 1 ], default )
134+
# https://stackoverflow.com/questions/73958753/return-all-extra-passed-to-pydantic-model
class NewBase(BaseModel):
    """Base class for generated config models that records unknown keys.

    When extra sections are allowed, keys that are not declared fields are
    moved into the ``extra`` field during validation instead of being kept as
    regular pydantic extras, so they can be reported via ``extra_flat``.
    """

    model_config = ConfigDict(strict=True, extra='allow' if allow_extra_sections else 'forbid')
    # Unknown keys collected by ``validator``; excluded from serialization.
    extra: Dict[str, Any] = Field(default_factory=dict, exclude=True)

    @model_validator(mode='before')
    @classmethod
    def validator(cls, values: Any) -> Any:
        """Split raw input into declared fields and extras before validation.

        Input is returned untouched when extras are forbidden, or when it is
        not a dict (so that pydantic reports the type error itself instead of
        this hook failing on ``values.items()``).
        """
        if cls.model_config.get('extra') != 'allow' or not isinstance(values, dict):
            return values

        extra: Dict[str, Any] = {}
        valid: Dict[str, Any] = {}
        for key, value in values.items():
            if key in cls.model_fields:
                valid[key] = value
            else:
                extra[key] = value
        # NOTE: ``extra`` is itself a declared field, so an input key literally
        # named ``extra`` is replaced by the collected extras here.
        valid['extra'] = extra
        return valid

    @property
    def extra_flat(self) -> Any:
        """Return extras of this model and of directly nested models.

        Keys of nested extras are prefixed with the field name and a dot
        (``'section.key'``). Models stored inside lists are not visited.
        """
        extra_flat = {**self.extra}
        for name, value in self:
            if isinstance(value, NewBase):
                data = {f'{name}.{k}': v for k, v in value.extra_flat.items()}
                extra_flat.update(data)
        return extra_flat
162+
def _model_from_dict(name: str, data: Dict[str, Any]) -> Type[BaseModel]:
    """Build a pydantic model class mirroring the structure of ``data``.

    Each (sanitized) key becomes a field whose default is the original value
    and whose type is derived from that value:

    * dict       -> nested model built recursively
    * non-empty list -> ``List[...]`` of the first item's type (a nested model
      named ``<Key>Item`` when the first item is a dict)
    * empty list -> ``List[Any]``
    * anything else -> ``type(value)``

    Args:
        name: name of the generated model class.
        data: default configuration used as the template.

    Returns:
        A model class derived from ``NewBase``.
    """
    field_specs: Dict[Any, Any] = {}
    for raw_key, default in data.items():
        key = _sanitize_section(raw_key)
        annotation: Any
        if isinstance(default, Dict):
            annotation = _model_from_dict(key, default)
        elif isinstance(default, list):
            if not default:
                # nothing to infer the element type from
                annotation = List[Any]
            else:
                head = default[0]
                if isinstance(head, Dict):
                    item_model = _model_from_dict(f'{key.capitalize()}Item', head)
                    annotation = List[item_model]  # type: ignore
                else:
                    annotation = List[type(head)]  # type: ignore
        else:
            annotation = type(default)
        field_specs[key] = (annotation, default)
    return create_model(name, **field_specs, __base__=NewBase)
131183
132184 try :
133185 with open (Path (config )) as fstream :
@@ -157,25 +209,20 @@ def _model_from_dict(name: str, data: Dict[str, Any], extra: bool) -> Type[BaseM
157209 # create copy for logging (only overwritten fields)
158210 _raw_data_copy = deepcopy (_raw_data )
159211 _merge_configs (_raw_data , _default_raw_data )
212+ logger .debug (f'\n \n Resulting config after merge:\n \n { pformat (_raw_data )} ' )
160213 if validate_data_types :
161- ConfigModel = _model_from_dict ('ConfigModel' , _default_raw_data , allow_extra_sections )
214+ ConfigModel = _model_from_dict ('ConfigModel' , _default_raw_data )
162215 try :
163216 validated_raw_data = ConfigModel .model_validate (_raw_data )
164- if validated_raw_data .model_extra :
165- extra_sections = validated_raw_data .model_extra
166- # remove formatted sections from extra
167- for k in _default_raw_data :
168- sk = _sanitize_section (k )
169- if sk in extra_sections :
170- extra_sections .pop (sk )
171- if extra_sections and warn_extra_sections :
217+ if extra_sections := validated_raw_data .extra_flat : # type: ignore
218+ if warn_extra_sections :
172219 logger .warning (
173220 f'\n \n The following extra sections will be ignored:\n \n { pformat (extra_sections )} '
174221 )
175222 # remove extra sections from resulting config
176223 for k in extra_sections :
177- _raw_data_copy . pop ( k , None )
178- _raw_data . pop ( k , None )
224+ _pop_nested ( _raw_data_copy , k )
225+ _pop_nested ( _raw_data , k )
179226 except Exception as e :
180227 err_msg = f'Failed validating config file: { e !r} '
181228 logger .error (err_msg )
@@ -184,9 +231,10 @@ def _model_from_dict(name: str, data: Dict[str, Any], extra: bool) -> Type[BaseM
184231 for k in _raw_data_copy .copy ():
185232 sk = _sanitize_section (k )
186233 if sk in _raw_data :
187- _raw_data_copy [sk ] = _raw_data [sk ]
188234 _raw_data_copy .pop (k , None )
189- logger .info (f'\n \n The following sections were overwritten:\n \n { pformat (_raw_data_copy )} ' )
235+ _raw_data_copy [sk ] = _raw_data [sk ]
236+ if _raw_data_copy :
237+ logger .info (f'\n \n The following sections were overwritten:\n \n { pformat (_raw_data_copy )} ' )
190238 try :
191239 raw_data = _munchify (_raw_data )
192240 logger .debug (f'\n \n Resulting config:\n \n { pformat (raw_data )} ' )
0 commit comments