3939
4040
4141class SourceConfig (BaseModel ):
42- source : HydraConfig [TensorSource ] # pyright: ignore[reportMissingTypeArgument]
42+ source : HydraConfig [TensorSource ]
4343 index_column : str
4444
4545
@@ -83,7 +83,7 @@ class BasePipelineConfig(BaseModel):
8383 def _validate_inputs (
8484 cls , value : Sequence [dict [str , Any ]]
8585 ) -> Sequence [dict [str , Any ]]:
86- if not mit .all_equal (map (tree_structure , value )): # pyright : ignore[reportArgumentType ]
86+ if not mit .all_equal (map (tree_structure , value )): # ty : ignore[invalid-argument-type ]
8787 msg = "inputs have different structures"
8888 raise ValueError (msg )
8989
@@ -171,7 +171,7 @@ def get_batch(
171171 self ,
172172 index : int | Sequence [int ] | slice | range ,
173173 * ,
174- keys : BatchKeys = BATCH_KEYS_DEFAULT ,
174+ keys : BatchKeys = BATCH_KEYS_DEFAULT , # ty: ignore[invalid-parameter-default]
175175 ) -> Batch :
176176 subkeys : dict [Literal ["data" , "meta" ], set [_ALL_TYPE | str ]] = {
177177 "data" : set (),
@@ -217,7 +217,7 @@ def get_batch(
217217
218218 source_data = {
219219 row [Column .source_id ]: torch .stack ([
220- self ._get_source (source )[idxs ] # pyright: ignore[reportUnknownMemberType]
220+ self ._get_source (source )[idxs ]
221221 for (source , idxs ) in zip (
222222 row [Column .source_config ],
223223 row [Column .source_idxs ],
@@ -232,7 +232,7 @@ def get_batch(
232232 sample_data_cols = (
233233 pl .all ()
234234 if _ALL in subkeys_data
235- else pl .col (subkeys_data - source_data .keys ()) # pyright: ignore[reportArgumentType]
235+ else pl .col (subkeys_data - source_data .keys ())
236236 ).exclude (Column .sample_idx , Column .input_id )
237237
238238 samples_subset = samples .select (sample_data_cols .to_physical ())
@@ -242,7 +242,7 @@ def get_batch(
242242 except TypeError :
243243 sample_data = samples_subset .to_dict (as_series = False )
244244
245- data = TensorDict (source_data | sample_data , batch_size = batch_size ) # pyright: ignore[reportArgumentType]
245+ data = TensorDict (source_data | sample_data , batch_size = batch_size )
246246
247247 else :
248248 data = None
@@ -277,8 +277,8 @@ def __len__(self) -> int:
277277 return len (self .samples )
278278
279279 @cache # noqa: B019
280- def _get_source (self , config : str ) -> TensorSource : # pyright: ignore[reportUnknownParameterType, reportMissingTypeArgument] # noqa: PLR6301
281- return HydraConfig [TensorSource ].model_validate_json (config ).instantiate () # pyright: ignore[reportUnknownVariableType, reportMissingTypeArgument]
280+ def _get_source (self , config : str ) -> TensorSource : # noqa: PLR6301
281+ return HydraConfig [TensorSource ].model_validate_json (config ).instantiate ()
282282
283283 @classmethod
284284 def _build_samples (
@@ -293,33 +293,33 @@ def _build_samples(
293293
294294 case PipelineHydraConfig ():
295295 pipeline = samples .pipeline .instantiate ()
296- executor : Executor | dict [OUTPUT_TYPE , Executor ] | None = tree_map ( # pyright : ignore[reportAssignmentType ]
296+ executor : Executor | dict [OUTPUT_TYPE , Executor ] | None = tree_map ( # ty : ignore[invalid-assignment ]
297297 HydraConfig [Executor ].instantiate ,
298- samples .executor , # pyright : ignore[reportArgumentType ]
298+ samples .executor , # ty : ignore[invalid-argument-type ]
299299 )
300300
301301 pipeline .print_documentation ()
302- inputs = tree_transpose ( # pyright: ignore[reportUnknownVariableType]
303- tree_structure (list (range (len (samples .inputs )))), # pyright : ignore[reportArgumentType ]
304- tree_structure (samples .inputs [0 ]), # pyright : ignore[reportArgumentType ]
305- samples .inputs , # pyright : ignore[reportArgumentType ]
302+ inputs = tree_transpose (
303+ tree_structure (list (range (len (samples .inputs )))), # ty : ignore[invalid-argument-type ]
304+ tree_structure (samples .inputs [0 ]), # ty : ignore[invalid-argument-type ]
305+ samples .inputs , # ty : ignore[invalid-argument-type ]
306306 )
307307
308308 results = pipeline .map (
309- inputs = inputs , # pyright : ignore[reportArgumentType ]
309+ inputs = inputs , # ty : ignore[invalid-argument-type ]
310310 executor = executor ,
311311 ** samples .model_dump (exclude = {"pipeline" , "inputs" , "executor" }),
312312 )
313313
314314 if pipeline .profile :
315315 pipeline .print_profiling_stats ()
316316
317- output_name : str = pipeline .unique_leaf_node .output_name # pyright: ignore[reportUnknownMemberType, reportAssignmentType]
317+ output_name : str = pipeline .unique_leaf_node .output_name
318318
319319 result : pl .DataFrame = (
320320 results [output_name ].output
321321 if results
322- else load_outputs (output_name , run_folder = samples .run_folder ) # pyright : ignore[reportArgumentType ]
322+ else load_outputs (output_name , run_folder = samples .run_folder ) # ty : ignore[invalid-argument-type ]
323323 )
324324
325325 return (
@@ -346,7 +346,7 @@ def _build_sources(cls, sources: SourcesConfig) -> pl.DataFrame:
346346 source_cfg .model_dump (exclude = {"source" })
347347 | {
348348 "id" : source_id ,
349- "config" : source_cfg .source .model_dump_json ( # pyright: ignore[reportUnknownMemberType]
349+ "config" : source_cfg .source .model_dump_json (
350350 by_alias = True
351351 ),
352352 }
@@ -357,8 +357,8 @@ def _build_sources(cls, sources: SourcesConfig) -> pl.DataFrame:
357357 ],
358358 schema_overrides = {Column .input_id : input_id_enum },
359359 )
360- .explode (k )
361- .unnest (k )
362- .select (Column .input_id , pl .exclude (Column .input_id ).name .prefix (f"{ k } ." ))
360+ .explode (k ) # ty: ignore[unresolved-reference]
361+ .unnest (k ) # ty: ignore[unresolved-reference]
362+ .select (Column .input_id , pl .exclude (Column .input_id ).name .prefix (f"{ k } ." )) # ty: ignore[unresolved-reference]
363363 .rechunk ()
364364 )
0 commit comments