Skip to content

Commit fe0804a

Browse files
committed
Improve future indexing
1 parent 4e1d4e6 commit fe0804a

File tree

1 file changed

+29
-11
lines changed

1 file changed

+29
-11
lines changed

src/zenml/execution/pipeline/dynamic/outputs.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,24 @@ def __init__(
131131
super().__init__(wrapped=wrapped, invocation_id=invocation_id)
132132
self._output_keys = output_keys
133133

134+
def get_artifact(self, key: str) -> ArtifactFuture:
135+
"""Get an artifact future by key.
136+
137+
Args:
138+
key: The key of the artifact future.
139+
140+
Returns:
141+
The artifact future.
142+
"""
143+
if key not in self._output_keys:
144+
raise KeyError(f"Invalid key: {key}")
145+
146+
return ArtifactFuture(
147+
wrapped=self._wrapped,
148+
invocation_id=self._invocation_id,
149+
index=self._output_keys.index(key),
150+
)
151+
134152
def artifacts(self) -> StepRunOutputs:
135153
"""Get the step run output artifacts.
136154
@@ -159,33 +177,33 @@ def load(self) -> Any:
159177
else:
160178
raise ValueError(f"Invalid step run output: {result}")
161179

162-
def __getitem__(self, key: Union[str, int]) -> ArtifactFuture:
180+
def __getitem__(self, key: Any) -> ArtifactFuture:
163181
"""Get an artifact future by key or index.
164182
165183
Args:
166184
key: The key or index of the artifact future.
167185
168186
Raises:
169-
ValueError: If the key or index is of an invalid type.
187+
TypeError: If the key is not an integer.
170188
IndexError: If the index is out of range.
171189
172190
Returns:
173191
The artifact future.
174192
"""
175-
if isinstance(key, str):
176-
index = self._output_keys.index(key)
177-
elif isinstance(key, int):
178-
index = key
179-
else:
180-
raise ValueError(f"Invalid key type: {type(key)}")
193+
if not isinstance(key, int):
194+
raise TypeError(f"Invalid key type: {type(key)}")
195+
196+
# Convert to positive index if necessary
197+
if key < 0:
198+
key += len(self._output_keys)
181199

182-
if index > len(self._output_keys):
183-
raise IndexError(f"Index out of range: {index}")
200+
if key > len(self._output_keys):
201+
raise IndexError(f"Index out of range: {key}")
184202

185203
return ArtifactFuture(
186204
wrapped=self._wrapped,
187205
invocation_id=self._invocation_id,
188-
index=index,
206+
index=key,
189207
)
190208

191209
def __iter__(self) -> Any:

0 commit comments

Comments
 (0)