Skip to content

Commit 9194875

Browse files
Atticus1806SimBe195albertz
authored
ModuleFactory (#21)
* module factory --------- Co-authored-by: SimBe195 <[email protected]> Co-authored-by: Albert Zeyer <[email protected]>
1 parent caf4e29 commit 9194875

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

i6_models/config.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ class ExampleModule(Module):
1717
from __future__ import annotations
1818
from dataclasses import dataclass, fields
1919
import typeguard
20+
from torch import nn
21+
from typing import Generic, TypeVar, Type
22+
import inspect
2023

2124

2225
@dataclass
@@ -51,3 +54,37 @@ def _validate_types(self) -> None:
5154

5255
def __post_init__(self) -> None:
5356
self._validate_types()
57+
58+
59+
ConfigType = TypeVar("ConfigType", bound=ModelConfiguration)
60+
ModuleType = TypeVar("ModuleType", bound=nn.Module)
61+
62+
63+
@dataclass
64+
class ModuleFactoryV1(Generic[ConfigType, ModuleType]):
65+
"""
66+
Dataclass for a combination of a Subassembly/Part and the corresponding configuration.
67+
Also provides a function to construct the corresponding object through this dataclass
68+
"""
69+
70+
module_class: Type[ModuleType]
71+
cfg: ConfigType
72+
73+
def __call__(self) -> ModuleType:
74+
"""Constructs an instance of the given module class"""
75+
return self.module_class(self.cfg)
76+
77+
def __post_init__(self) -> None:
78+
# Check typing of module_class and cfg, i.e. make sure that "self.module_class(self.cfg)" is a valid call.
79+
parameters = inspect.signature(self.module_class).parameters.values()
80+
assert len(parameters) >= 1
81+
parameter_iter = iter(parameters)
82+
83+
# 1. Check that the first parameter is either not annotated or the annotation matches the type of self.cfg
84+
cfg_parameter = next(parameter_iter)
85+
if cfg_parameter.annotation is not inspect.Parameter.empty:
86+
typeguard.check_type(self.cfg, cfg_parameter.annotation)
87+
88+
# 2. Check that all other parameters have default values
89+
for parameter in parameter_iter:
90+
assert parameter.default is not inspect.Parameter.empty

tests/test_config.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from dataclasses import dataclass
22
import pytest
33

4-
from i6_models.config import ModelConfiguration
4+
from i6_models.config import ModelConfiguration, ModuleFactoryV1
5+
from torch import nn
56

67

78
def test_simple_configuration():
@@ -58,3 +59,31 @@ class TestConfiguration(ModelConfiguration):
5859
TestConfiguration(num_layers=2, hidden_dim=1)
5960
with pytest.raises(TypeCheckError):
6061
TestConfiguration(num_layers=2.0, hidden_dim="One")
62+
63+
64+
def test_module_factory():
65+
@dataclass
66+
class TestConfiguration(ModelConfiguration):
67+
param: int = 4
68+
69+
@dataclass
70+
class TestConfiguration2(ModelConfiguration):
71+
param: float = 4.3
72+
73+
class TestModule(nn.Module):
74+
def __init__(self, cfg: TestConfiguration):
75+
super().__init__()
76+
self.cfg = cfg
77+
78+
def forward(self):
79+
pass
80+
81+
factory = ModuleFactoryV1(module_class=TestModule, cfg=TestConfiguration())
82+
obj = factory()
83+
assert type(obj) == TestModule
84+
obj2 = factory()
85+
assert obj != obj2
86+
from typeguard import TypeCheckError
87+
88+
with pytest.raises(TypeCheckError):
89+
ModuleFactoryV1(module_class=TestModule, cfg=TestConfiguration2())

0 commit comments

Comments
 (0)