33from copy import deepcopy
44from pathlib import Path
55from pprint import pformat
6- from typing import Any , Dict , List , Optional , Type , Union
6+ from typing import Any , Dict , List , Optional , Tuple , Type , Union
77
88import yaml as _yaml
99from camel_converter import to_snake as _to_snake
@@ -37,6 +37,8 @@ def init_config(
3737 validate_data_types : bool = True ,
3838 allow_extra_sections : bool = True ,
3939 warn_extra_sections : bool = True ,
40+ _generate_stub : bool = False ,
41+ _stub_variable_name : str = 'config' ,
4042) -> PyyaConfig :
4143 """Initialize attribute-stylish configuration from YAML file.
4244
@@ -154,24 +156,85 @@ def extra_flat(self) -> Any:
154156 extra_flat .update (data )
155157 return extra_flat
156158
157- def _model_from_dict (name : str , data : Dict [str , Any ]) -> Type [BaseModel ]:
159+ def _model_and_stub_from_dict (
160+ name : str , data : Dict [str , Any ], path : Optional [List [str ]] = None
161+ ) -> Tuple [Type [ExtraBase ], str ]:
158162 fields : Dict [Any , Any ] = {}
163+ if path is None :
164+ path = []
165+ class_name = '' .join (part .capitalize () if i > 0 else part for i , part in enumerate (path + [name ]))
166+ stub_lines = [f'class { class_name } :' ]
167+ nested_stubs = []
168+ py_type : Any
159169 for section , entry in data .items ():
160170 if isinstance (entry , Dict ):
161- nested_model = _model_from_dict (section , entry )
171+ nested_model , nested_stub = _model_and_stub_from_dict (section , entry , path + [name ])
172+ if not keyword .iskeyword (section ) and section .isidentifier ():
173+ stub_lines .append (f' { section } : { class_name + section .capitalize ()} ' )
174+ nested_stubs .append (nested_stub )
162175 fields [section ] = (nested_model , entry )
163176 elif isinstance (entry , list ) and entry :
164177 first_item = entry [0 ]
165178 if isinstance (first_item , Dict ):
166- nested_model = _model_from_dict (f'{ section .capitalize ()} Item' , first_item )
179+ nested_model , nested_stub = _model_and_stub_from_dict (
180+ f'{ section .capitalize ()} _item' , first_item , path + [name ]
181+ )
182+ if not keyword .iskeyword (section ) and section .isidentifier ():
183+ stub_lines .append (f' { section } : List[{ class_name + section .capitalize ()} _item]' )
184+ nested_stubs .append (nested_stub )
167185 fields [section ] = (List [nested_model ], entry ) # type: ignore
168186 else :
169- fields [section ] = (List [type (first_item )], entry ) # type: ignore
187+ py_type = type (first_item )
188+ if not keyword .iskeyword (section ) and section .isidentifier ():
189+ stub_lines .append (f' { section } : List[{ py_type .__name__ } ]' )
190+ fields [section ] = (List [py_type ], entry )
170191 elif isinstance (entry , list ):
192+ if not keyword .iskeyword (section ) and section .isidentifier ():
193+ stub_lines .append (f' { section } : List[Any]' )
171194 fields [section ] = (List [Any ], entry )
172195 else :
173- fields [section ] = (type (entry ), entry )
174- return create_model (name , ** fields , __base__ = ExtraBase )
196+ py_type = type (entry )
197+ if not keyword .iskeyword (section ) and section .isidentifier ():
198+ stub_lines .append (f' { section } : { py_type .__name__ } ' )
199+ fields [section ] = (py_type , entry )
200+ stub_code = '\n \n ' .join (nested_stubs + ['\n ' .join (stub_lines )]).replace ('-' , '_' )
201+ return create_model (name , ** fields , __base__ = ExtraBase ), stub_code
202+
203+ def _get_default_raw_data () -> ConfigType :
204+ try :
205+ try :
206+ with open (Path (default_config )) as fstream :
207+ _default_raw_data : Optional [ConfigType ] = _yaml .safe_load (fstream )
208+ except _yaml .YAMLError as e :
209+ err_msg = f'{ default_config } file is corrupted: { e } '
210+ logger .error (err_msg )
211+ raise PyyaError (err_msg ) from None
212+ if _default_raw_data is None :
213+ raise FileNotFoundError ()
214+ except FileNotFoundError as e :
215+ logger .error (e )
216+ raise PyyaError (f'{ default_config } file is missing or empty' ) from None
217+ _default_raw_data = _sanitize_keys (_default_raw_data )
218+ return _default_raw_data
219+
220+ if _generate_stub :
221+ output_file = Path (config )
222+ if output_file .exists ():
223+ err_msg = f'{ output_file } already exists'
224+ logger .error (err_msg )
225+ raise PyyaError (err_msg )
226+ _default_raw_data = _get_default_raw_data ()
227+ _ , stub = _model_and_stub_from_dict ('Config' , _default_raw_data )
228+ stub_full = (
229+ f'# { output_file } was autogenerated with pyya CLI tool, see `pyya -h`\n from typing import Any, List\n \n '
230+ f'{ stub } \n \n '
231+ '# for type hints to work the variable name created with pyya.init_config\n '
232+ '# should have the same name (e.g. config = pyya.init_config())\n '
233+ f'{ _stub_variable_name } : Config\n '
234+ )
235+ output_file .write_text (stub_full )
236+ logger .info (f'{ output_file } created' )
237+ return PyyaConfig ()
175238
176239 try :
177240 with open (Path (config )) as fstream :
@@ -194,26 +257,13 @@ def _model_from_dict(name: str, data: Dict[str, Any]) -> Type[BaseModel]:
194257 err_msg = f'Failed parsing `sections_ignored_on_merge`: { e !r} '
195258 logger .error (err_msg )
196259 raise PyyaError (err_msg ) from None
197- try :
198- try :
199- with open (Path (default_config )) as fstream :
200- _default_raw_data : Optional [ConfigType ] = _yaml .safe_load (fstream )
201- _default_raw_data = _sanitize_keys (_default_raw_data )
202- except _yaml .YAMLError as e :
203- err_msg = f'{ default_config } file is corrupted: { e } '
204- logger .error (err_msg )
205- raise PyyaError (err_msg ) from None
206- if _default_raw_data is None :
207- raise FileNotFoundError ()
208- except FileNotFoundError as e :
209- logger .error (e )
210- raise PyyaError (f'{ default_config } file is missing or empty' ) from None
260+ _default_raw_data = _get_default_raw_data ()
211261 # create copy for logging (only overwritten fields)
212262 _raw_data_copy = deepcopy (_raw_data )
213263 _merge_configs (_raw_data , _default_raw_data )
214264 logger .debug (f'Resulting config after merge:\n { pformat (_raw_data )} ' )
215265 if validate_data_types :
216- ConfigModel = _model_from_dict ('ConfigModel' , _default_raw_data )
266+ ConfigModel , _ = _model_and_stub_from_dict ('ConfigModel' , _default_raw_data )
217267 try :
218268 validated_raw_data = ConfigModel .model_validate (_raw_data )
219269 if extra_sections := validated_raw_data .extra_flat : # type: ignore
@@ -231,7 +281,7 @@ def _model_from_dict(name: str, data: Dict[str, Any]) -> Type[BaseModel]:
231281 logger .info (f'The following sections were overwritten:\n { pformat (_raw_data_copy )} ' )
232282 try :
233283 logger .debug (f'Resulting config:\n { pformat (_raw_data )} ' )
234- return _munchify (_raw_data )
284+ return PyyaConfig ( _munchify (_raw_data ) )
235285 except Exception as e :
236286 err_msg = f'Failed parsing config file: { e !r} '
237287 logger .error (err_msg )
0 commit comments