@@ -131,6 +131,24 @@ def __init__(
131131 super ().__init__ (wrapped = wrapped , invocation_id = invocation_id )
132132 self ._output_keys = output_keys
133133
134+ def get_artifact (self , key : str ) -> ArtifactFuture :
135+ """Get an artifact future by key.
136+
137+ Args:
138+ key: The key of the artifact future.
139+
140+ Returns:
141+ The artifact future.
142+ """
143+ if key not in self ._output_keys :
144+ raise KeyError (f"Invalid key: { key } " )
145+
146+ return ArtifactFuture (
147+ wrapped = self ._wrapped ,
148+ invocation_id = self ._invocation_id ,
149+ index = self ._output_keys .index (key ),
150+ )
151+
134152 def artifacts (self ) -> StepRunOutputs :
135153 """Get the step run output artifacts.
136154
@@ -159,33 +177,33 @@ def load(self) -> Any:
159177 else :
160178 raise ValueError (f"Invalid step run output: { result } " )
161179
162- def __getitem__ (self , key : Union [ str , int ] ) -> ArtifactFuture :
180+ def __getitem__ (self , key : Any ) -> ArtifactFuture :
163181 """Get an artifact future by key or index.
164182
165183 Args:
166184 key: The key or index of the artifact future.
167185
168186 Raises:
169- ValueError : If the key or index is of an invalid type .
187+ TypeError : If the key is not an integer .
170188 IndexError: If the index is out of range.
171189
172190 Returns:
173191 The artifact future.
174192 """
175- if isinstance (key , str ):
176- index = self . _output_keys . index (key )
177- elif isinstance ( key , int ):
178- index = key
179- else :
180- raise ValueError ( f"Invalid key type: { type ( key ) } " )
193+ if not isinstance (key , int ):
194+ raise TypeError ( f"Invalid key type: { type (key ) } " )
195+
196+ # Convert to positive index if necessary
197+ if key < 0 :
198+ key += len ( self . _output_keys )
181199
182- if index > len (self ._output_keys ):
183- raise IndexError (f"Index out of range: { index } " )
200+ if key > len (self ._output_keys ):
201+ raise IndexError (f"Index out of range: { key } " )
184202
185203 return ArtifactFuture (
186204 wrapped = self ._wrapped ,
187205 invocation_id = self ._invocation_id ,
188- index = index ,
206+ index = key ,
189207 )
190208
191209 def __iter__ (self ) -> Any :
0 commit comments