@@ -66,7 +66,7 @@ def __init__(
6666 if streams is not None and (
6767 missing_stream_indexes := (
6868 {stream_config .index for stream_config in streams .values ()}
69- - (data_keys := data .keys ())
69+ - (data_keys := set ( data .keys () ))
7070 )
7171 ):
7272 logger .error (
@@ -119,7 +119,7 @@ def streams(self) -> StreamsConfig | None:
119119
120120 @override
121121 def __getitem__ (self , index : int ) -> Batch :
122- return self .get_batch ([index ])[0 ]
122+ return self .get_batch ([index ])[0 ] # ty: ignore[invalid-return-type]
123123
124124 def __getitems__ (self , index : Sequence [int ]) -> Batch : # noqa: PLW3201
125125 return self .get_batch (index )
@@ -134,25 +134,25 @@ def get_batch(
134134 include_streams : bool | None = None ,
135135 include_meta : bool = True ,
136136 ) -> Batch :
137- data = self .data [index ]
137+ data = self .data [index ] # ty: ignore[invalid-argument-type]
138138 meta = self .meta [index ]
139139
140140 match include_streams , self .streams :
141141 case None | True , dict ():
142- stream_data = {stream_id : [] for stream_id in self .streams }
142+ stream_data = {stream_id : [] for stream_id in self .streams } # ty: ignore[not-iterable]
143143
144144 for sample , input_id in zip (data , meta ["input_id" ], strict = True ):
145- for stream_id , stream_config in self .streams .items ():
145+ for stream_id , stream_config in self .streams .items (): # ty: ignore[possibly-missing-attribute]
146146 stream_index = sample [stream_config .index ].tolist ()
147147 source = self ._get_source (stream_id , input_id )
148148 stream_data [stream_id ].append (source [stream_index ])
149149
150150 stream_data = {k : torch .stack (v ) for k , v in stream_data .items ()}
151151
152- if data .is_locked :
153- data = data .clone (recurse = True )
152+ if data .is_locked : # ty: ignore[possibly-missing-attribute]
153+ data = data .clone (recurse = True ) # ty: ignore[unknown-argument]
154154
155- data = data .update (stream_data , inplace = False )
155+ data = data .update (stream_data , inplace = False ) # ty: ignore[possibly-missing-attribute]
156156
157157 case True , None :
158158 msg = "`include_streams` is True but no streams specified"
@@ -173,6 +173,10 @@ def get_batch(
173173
174174 @cachedmethod (lambda self : self ._stream_source_cache )
175175 def _get_source (self , stream_id : str , input_id : str ) -> TensorSource :
176+ if self .streams is None :
177+ msg = "streams not specified"
178+ raise RuntimeError (msg )
179+
176180 return self .streams [stream_id ].sources [input_id ].instantiate ()
177181
178182 @classmethod
@@ -194,7 +198,7 @@ def _build_samples(
194198 )
195199
196200 output_name = pipeline .unique_leaf_node .output_name
197- results = pipeline .map ( # ty: ignore[missing-argument]
201+ results = pipeline .map (
198202 executor = executor , ** samples .model_dump (exclude = {"pipeline" , "executor" })
199203 )
200204
0 commit comments