Skip to content

Commit 2cdbc4b

Browse files
JackTemakialbertzAtticus1806
authored
add unhashed_package_root and PartialImport (#157)
* add unhashed_package_root and PartialImport * Apply suggestions from code review Co-authored-by: Albert Zeyer <[email protected]> * make first parameter positional * restructure partial import string * Apply suggestions from code review Co-authored-by: Benedikt Hilmes <[email protected]> * better examples in docstring * re-add self.package, update docstring --------- Co-authored-by: Albert Zeyer <[email protected]> Co-authored-by: Benedikt Hilmes <[email protected]>
1 parent 4dd6fd0 commit 2cdbc4b

File tree

1 file changed

+90
-5
lines changed

1 file changed

+90
-5
lines changed

common/setups/serialization.py

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
"""
44

55
from __future__ import annotations
6-
from typing import Any, Union, Optional, List
6+
from typing import Any, Dict, Union, Optional, List
77
from types import FunctionType
8+
import string
89
import sys
910
import textwrap
1011

1112
from sisyphus import tk
1213
from sisyphus.hash import sis_hash_helper, short_hash
1314
from sisyphus.delayed_ops import DelayedBase
1415

15-
from i6_core.util import uopen
16+
from i6_core.util import uopen, instanciate_delayed
1617

1718

1819
class SerializerObject(DelayedBase):
@@ -72,16 +73,24 @@ class Import(SerializerObject):
7273
def __init__(
7374
self,
7475
code_object_path: Union[str, FunctionType, Any],
75-
import_as: Optional[str] = None,
7676
*,
77+
unhashed_package_root: Optional[str] = None,
78+
import_as: Optional[str] = None,
7779
use_for_hash: bool = True,
7880
ignore_import_as_for_hash: bool = False,
7981
):
8082
"""
81-
:param code_object_path: e.g. `i6_experiments.users.username.my_rc_files.SomeNiceASRModel`.
83+
:param code_object_path: e.g.`i6_experiments.users.username.some_experiment.pytorch_networks.SomeNiceASRModel`.
8284
This can be the object itself, e.g. a function or a class. Then it will use __qualname__ and __module__.
85+
:param unhashed_package_root: The root path to a package, from where relatives paths will be hashed.
86+
Recommended is to use the root folder of an experiment module. E.g.:
87+
`i6_experiments.users.username.some_experiment`
88+
which could be retrieved via `__package__` from a module in the root of the `some_experiment` folder.
89+
In case one wants to avoid hash conflicts this might cause, passing an `ExplicitHash` object to the
90+
same collection as the import is possible.
8391
:param import_as: if given, the code object will be imported as this name
84-
:param use_for_hash:
92+
:param use_for_hash: if False, this import is not hashed when passed to a Collection/Serializer
93+
:param ignore_import_as_for_hash: do not hash `import_as` if set
8594
"""
8695
super().__init__()
8796
if not isinstance(code_object_path, str):
@@ -96,6 +105,14 @@ def __init__(
96105
self.object_name = self.code_object.split(".")[-1]
97106
self.module = ".".join(self.code_object.split(".")[:-1])
98107
self.package = ".".join(self.code_object.split(".")[:-2])
108+
109+
if unhashed_package_root:
110+
if not self.code_object.startswith(unhashed_package_root):
111+
raise ValueError(
112+
f"unhashed_package_root: {unhashed_package_root} is not a prefix of {self.code_object}"
113+
)
114+
self.code_object = self.code_object[len(unhashed_package_root) :]
115+
99116
self.import_as = import_as
100117
self.use_for_hash = use_for_hash
101118
self.ignore_import_as_for_hash = ignore_import_as_for_hash
@@ -112,6 +129,74 @@ def _sis_hash(self):
112129
return sis_hash_helper(self.code_object)
113130

114131

132+
class PartialImport(Import):
133+
"""
134+
Like Import, but for partial callables where certain parameters are given fixed and are hashed.
135+
"""
136+
137+
TEMPLATE = textwrap.dedent(
138+
"""\
139+
${OBJECT_NAME} = __import__("functools").partial(
140+
__import__("${IMPORT_PATH}", fromlist=["${IMPORT_NAME}"]).${IMPORT_NAME},
141+
**${KWARGS}
142+
)
143+
"""
144+
)
145+
146+
def __init__(
147+
self,
148+
*,
149+
code_object_path: Union[str, FunctionType, Any],
150+
unhashed_package_root: str,
151+
hashed_arguments: Dict[str, Any],
152+
unhashed_arguments: Dict[str, Any],
153+
import_as: Optional[str] = None,
154+
use_for_hash: bool = True,
155+
ignore_import_as_for_hash: bool = False,
156+
):
157+
"""
158+
:param code_object_path: e.g.`i6_experiments.users.username.some_experiment.pytorch_networks.SomeNiceASRModel`.
159+
This can be the object itself, e.g. a function or a class. Then it will use __qualname__ and __module__.
160+
:param unhashed_package_root: The root path to a package, from where relatives paths will be hashed.
161+
Recommended is to use the root folder of an experiment module. E.g.:
162+
`i6_experiments.users.username.some_experiment`
163+
which could be retrieved via `__package__` from a module in the root of the `some_experiment` folder.
164+
In case one wants to avoid hash conflicts this might cause, passing an `ExplicitHash` object to the
165+
same collection as the import is possible.
166+
:param hashed_arguments: argument dictionary for addition partial arguments to set to the callable.
167+
Will be serialized as dict into the config, so make sure to use only serializable/parseable content
168+
:param unhashed_arguments: same as above, but does not influence the hash
169+
:param import_as: if given, the code object will be imported as this name
170+
:param use_for_hash: if False, this module is not hashed when passed to a Collection/Serializer
171+
:param ignore_import_as_for_hash: do not hash `import_as` if set
172+
"""
173+
174+
super().__init__(
175+
code_object_path=code_object_path,
176+
unhashed_package_root=unhashed_package_root,
177+
import_as=import_as,
178+
use_for_hash=use_for_hash,
179+
ignore_import_as_for_hash=ignore_import_as_for_hash,
180+
)
181+
self.hashed_arguments = hashed_arguments
182+
self.unhashed_arguments = unhashed_arguments
183+
184+
def get(self) -> str:
185+
arguments = {**self.unhashed_arguments, **self.hashed_arguments}
186+
return string.Template(self.TEMPLATE).substitute(
187+
{
188+
"KWARGS": str(instanciate_delayed(arguments)),
189+
"IMPORT_PATH": self.module,
190+
"IMPORT_NAME": self.object_name,
191+
"OBJECT_NAME": self.import_as if self.import_as is not None else self.object_name,
192+
}
193+
)
194+
195+
def _sis_hash(self):
196+
super_hash = super()._sis_hash()
197+
return sis_hash_helper({"import": super_hash, "hashed_arguments": self.hashed_arguments})
198+
199+
115200
class ExternalImport(SerializerObject):
116201
"""
117202
Import from e.g. a git repository. For imports within the recipes use "Import".

0 commit comments

Comments
 (0)