Skip to content

Commit 407cf43

Browse files
Atticus1806albertzvietingSimBe195
authored
Add ModelConfiguration class (#3)
Co-authored-by: Albert Zeyer <[email protected]> Co-authored-by: vieting <[email protected]> Co-authored-by: SimBe195 <[email protected]>
1 parent e96a497 commit 407cf43

File tree

4 files changed

+124
-0
lines changed

4 files changed

+124
-0
lines changed

.github/workflows/model_tests.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: model_tests
2+
on:
3+
push:
4+
branches:
5+
- main
6+
pull_request:
7+
branches:
8+
- main
9+
jobs:
10+
test-jobs:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v2
14+
with:
15+
repository: "rwth-i6/i6_models"
16+
path: ""
17+
- uses: actions/setup-python@v2
18+
with:
19+
python-version: 3.7
20+
cache: 'pip'
21+
- run: |
22+
pip install pytest
23+
pip install -r requirements.txt
24+
- name: Test Models
25+
run: |
26+
python -m pytest tests

i6_models/config.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""
2+
Provides the base class for configurations, model configurations can be derived from this base class e.g:
3+
4+
@dataclass
5+
class ExampleConfig(ModelConfiguration):
6+
hidden_dim: int = 256
7+
name: str = "Example Configuration"
8+
9+
class ExampleModule(Module):
10+
__init__(self, cfg: ExampleConfig)
11+
self.hidden_dim = cfg.hidden_dim
12+
13+
This config can then be used in the construction of the model to provide parameters.
14+
Similar approach as done in Fairseq: https://github.com/facebookresearch/fairseq/blob/main/fairseq/dataclass/configs.py
15+
"""
16+
17+
from __future__ import annotations
18+
from dataclasses import dataclass, fields
19+
import typeguard
20+
21+
22+
@dataclass
23+
class ModelConfiguration:
24+
"""
25+
Base dataclass for configuration of different primitives, parts and assemblies.
26+
Enforces type checking for the creation of derived configs.
27+
"""
28+
29+
def _validate_types(self) -> None:
30+
for field in fields(type(self)):
31+
try:
32+
typeguard.check_type(getattr(self, field.name), field.type)
33+
except typeguard.TypeCheckError as exc:
34+
raise typeguard.TypeCheckError(f'In field "{field.name}" of "{type(self)}": {exc}')
35+
36+
def __post_init__(self) -> None:
37+
self._validate_types()

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
typeguard

tests/test_config.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from dataclasses import dataclass
2+
import pytest
3+
4+
from i6_models.config import ModelConfiguration
5+
6+
7+
def test_simple_configuration():
8+
@dataclass
9+
class TestConfiguration(ModelConfiguration):
10+
num_layers: int = 5
11+
hidden_dim: int = 256
12+
name: str = "Cool Model Configuration"
13+
14+
test_cfg = TestConfiguration(num_layers=12, name="Even Cooler Model Configuration")
15+
assert test_cfg.hidden_dim == 256
16+
assert test_cfg.name == "Even Cooler Model Configuration"
17+
assert test_cfg.num_layers == 12
18+
test_cfg.num_layers = 7
19+
assert test_cfg.num_layers == 7
20+
21+
22+
def test_nested_configuration():
23+
@dataclass
24+
class TestConfiguration(ModelConfiguration):
25+
num_layers: int = 5
26+
hidden_dim: int = 256
27+
name: str = "Cool Model Configuration"
28+
29+
@dataclass
30+
class TestNestedConfiguration(ModelConfiguration):
31+
encoder_config: TestConfiguration = TestConfiguration(num_layers=4, hidden_dim=3, name="encoder_config")
32+
decoder_config: TestConfiguration = TestConfiguration(num_layers=6, hidden_dim=5, name="decoder_config")
33+
34+
dec_cfg = TestConfiguration()
35+
test_cfg = TestNestedConfiguration(decoder_config=dec_cfg)
36+
37+
assert test_cfg.encoder_config.num_layers == 4
38+
assert test_cfg.encoder_config.hidden_dim == 3
39+
assert test_cfg.encoder_config.name == "encoder_config"
40+
assert test_cfg.decoder_config.num_layers == 5
41+
assert test_cfg.decoder_config.hidden_dim == 256
42+
assert test_cfg.decoder_config.name == "Cool Model Configuration"
43+
test_cfg.encoder_config = TestConfiguration(num_layers=1, hidden_dim=2, name="better_encoder_config")
44+
assert test_cfg.encoder_config.num_layers == 1
45+
assert test_cfg.encoder_config.hidden_dim == 2
46+
assert test_cfg.encoder_config.name == "better_encoder_config"
47+
48+
49+
def test_config_typing():
50+
@dataclass
51+
class TestConfiguration(ModelConfiguration):
52+
num_layers: int = 4
53+
hidden_dim: int = 13
54+
name: str = "Cool Model Configuration"
55+
56+
from typeguard import TypeCheckError
57+
58+
TestConfiguration(num_layers=2, hidden_dim=1)
59+
with pytest.raises(TypeCheckError):
60+
TestConfiguration(num_layers=2.0, hidden_dim="One")

0 commit comments

Comments
 (0)