@@ -17,6 +17,9 @@ class ExampleModule(Module):
1717from __future__ import annotations
1818from dataclasses import dataclass , fields
1919import typeguard
20+ from torch import nn
21+ from typing import Generic , TypeVar , Type
22+ import inspect
2023
2124
2225@dataclass
@@ -51,3 +54,37 @@ def _validate_types(self) -> None:
5154
5255 def __post_init__ (self ) -> None :
5356 self ._validate_types ()
57+
58+
59+ ConfigType = TypeVar ("ConfigType" , bound = ModelConfiguration )
60+ ModuleType = TypeVar ("ModuleType" , bound = nn .Module )
61+
62+
63+ @dataclass
64+ class ModuleFactoryV1 (Generic [ConfigType , ModuleType ]):
65+ """
66+ Dataclass for a combination of a Subassembly/Part and the corresponding configuration.
67+ Also provides a function to construct the corresponding object through this dataclass
68+ """
69+
70+ module_class : Type [ModuleType ]
71+ cfg : ConfigType
72+
73+ def __call__ (self ) -> ModuleType :
74+ """Constructs an instance of the given module class"""
75+ return self .module_class (self .cfg )
76+
77+ def __post_init__ (self ) -> None :
78+ # Check typing of module_class and cfg, i.e. make sure that "self.module_class(self.cfg)" is a valid call.
79+ parameters = inspect .signature (self .module_class ).parameters .values ()
80+ assert len (parameters ) >= 1
81+ parameter_iter = iter (parameters )
82+
83+ # 1. Check that the first parameter is either not annotated or the annotation matches the type of self.cfg
84+ cfg_parameter = next (parameter_iter )
85+ if cfg_parameter .annotation is not inspect .Parameter .empty :
86+ typeguard .check_type (self .cfg , cfg_parameter .annotation )
87+
88+ # 2. Check that all other parameters have default values
89+ for parameter in parameter_iter :
90+ assert parameter .default is not inspect .Parameter .empty
0 commit comments