Skip to content

Commit 2d20ada

Browse files
author
Nick Manganelli
committed
Move valid_format to class method, pre-commit fixes
1 parent 9f4b0b5 commit 2d20ada

File tree

3 files changed

+35
-40
lines changed

3 files changed

+35
-40
lines changed

src/coffea/dataset_tools/filespec.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ def num_selected_entries(self) -> int | None:
199199
total += v.num_selected_entries
200200
return total
201201

202-
def limit_steps(self, max_steps: int | slice | None, per_file: bool = False) -> Self:
202+
def limit_steps(
203+
self, max_steps: int | slice | None, per_file: bool = False
204+
) -> Self:
203205
"""Limit the steps. pass per_file=True to limit steps per file, otherwise limits across all files cumulatively"""
204206

205207
if max_steps is None:
@@ -208,7 +210,10 @@ def limit_steps(self, max_steps: int | slice | None, per_file: bool = False) ->
208210
return type(self)({k: v.limit_steps(max_steps) for k, v in self.items()})
209211
else:
210212
from coffea.dataset_tools.manipulations import _concatenated_step_slice
211-
steps_by_file = _concatenated_step_slice({k: v.steps for k, v in self.items()}, max_steps)
213+
214+
steps_by_file = _concatenated_step_slice(
215+
{k: v.steps for k, v in self.items()}, max_steps
216+
)
212217
new_dict = {}
213218
for k, v in self.items():
214219
if len(steps_by_file[k]) > 0:
@@ -398,6 +403,12 @@ def _check_form(self) -> bool | None:
398403
else:
399404
return None
400405

406+
def _valid_format(self) -> bool:
407+
_formats = {"root", "parquet"}
408+
return self.format in _formats or all(
409+
fmt in _formats for fmt in self.format.split("|")
410+
)
411+
401412
def set_check_format(self) -> bool:
402413
"""Set and/or validate the format if manually specified"""
403414
if self.format is None:
@@ -411,9 +422,7 @@ def set_check_format(self) -> bool:
411422
self.format = "|".join(union)
412423

413424
# validate the format, if present
414-
if not ModelFactory.valid_format(self.format):
415-
return False
416-
return True
425+
return self._valid_format()
417426

418427
@model_validator(mode="after")
419428
def post_validate(self) -> Self:
@@ -523,17 +532,17 @@ def steps(self) -> dict[str, list[StepPair]] | None:
523532
"""Get the steps per dataset file, if available."""
524533
return {k: v.steps for k, v in self.items()}
525534

526-
def limit_steps(
527-
self, max_steps: int | slice, per_file: bool = False
528-
) -> Self:
535+
def limit_steps(self, max_steps: int | slice, per_file: bool = False) -> Self:
529536
"""Limit the steps"""
530537
spec = copy.deepcopy(self)
531538
# handle both per_file True and False by passthrough
532539
for k, v in spec.items():
533540
spec[k] = v.limit_steps(max_steps, per_file=per_file)
534541
return type(self)(spec)
535542

536-
def limit_files(self, max_files: int | slice | None, per_dataset: bool = True) -> Self:
543+
def limit_files(
544+
self, max_files: int | slice | None, per_dataset: bool = True
545+
) -> Self:
537546
"""Limit the number of files."""
538547
spec = copy.deepcopy(self)
539548
if per_dataset:
@@ -595,21 +604,9 @@ def identify_file_format(name_or_directory: str) -> str:
595604

596605

597606
class ModelFactory:
598-
_formats = {"root", "parquet"}
599-
600607
def __init__(self):
601608
pass
602609

603-
@classmethod
604-
def valid_format(cls, format: str | DatasetSpec) -> bool:
605-
if isinstance(format, DatasetSpec):
606-
test_format = format.format
607-
else:
608-
test_format = format
609-
return test_format in cls._formats or all(
610-
fmt in cls._formats for fmt in test_format.split("|")
611-
)
612-
613610
@classmethod
614611
def attempt_promotion(
615612
cls,
@@ -694,7 +691,7 @@ def dict_to_datasetspec(cls, input: dict[str, Any], verbose=False) -> DatasetSpe
694691
def datasetspec_to_dict(
695692
cls,
696693
input: DatasetSpec,
697-
coerce_filespec_to_dict=True,
694+
coerce_filespec_to_dict: bool = True,
698695
) -> dict[str, Any]:
699696
assert isinstance(
700697
input, DatasetSpec

src/coffea/dataset_tools/manipulations.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import copy
44
from typing import Any, Callable, Protocol, runtime_checkable
5+
56
try:
67
from typing import Self
78
except ImportError:
@@ -19,21 +20,22 @@
1920
PreprocessedFiles,
2021
)
2122

23+
2224
# protocol for pydantic types that implement limit_files
2325
@runtime_checkable
2426
class LimitFilesProtocol(Protocol):
2527
# handle both limit_files with max_files and max_files + per_dataset
26-
def limit_files(self, max_files: int | slice, per_dataset: bool = True) -> Self:
27-
...
28+
def limit_files(self, max_files: int | slice, per_dataset: bool = True) -> Self: ...
2829
@runtime_checkable
2930
class LimitStepsProtocol(Protocol):
3031
def limit_steps(
3132
self, max_steps: int | slice, per_file: bool = False, per_dataset: bool = True
32-
) -> Self:
33-
...
33+
) -> Self: ...
3434

3535

36-
def max_chunks(fileset: LimitStepsProtocol | FilesetSpec, maxchunks: int | None = None) -> FilesetSpec:
36+
def max_chunks(
37+
fileset: LimitStepsProtocol | FilesetSpec, maxchunks: int | None = None
38+
) -> FilesetSpec:
3739
"""
3840
Modify the input fileset so that only the first "maxchunks" chunks of each dataset will be processed.
3941
@@ -72,7 +74,10 @@ def max_chunks_per_file(
7274
"""
7375
return slice_chunks(fileset, slice(maxchunks), bydataset=False)
7476

75-
def _concatenated_step_slice(stepdict: dict[str, Any], theslice: int | slice) -> dict[str, Any]:
77+
78+
def _concatenated_step_slice(
79+
stepdict: dict[str, Any], theslice: int | slice
80+
) -> dict[str, Any]:
7681
"""
7782
Modify the input step description to only contain the steps specified by the input slice.
7883
@@ -104,11 +109,13 @@ def _concatenated_step_slice(stepdict: dict[str, Any], theslice: int | slice) ->
104109
# 3) repopulate in order, up to maxchunks total
105110
for key, step in kept:
106111
out[key].append(step)
107-
return out # {key: steps for key, steps in out.items() if steps}
112+
return out # {key: steps for key, steps in out.items() if steps}
108113

109114

110115
def slice_chunks(
111-
fileset: LimitStepsProtocol | FilesetSpec, theslice: Any = slice(None), bydataset: bool = True
116+
fileset: LimitStepsProtocol | FilesetSpec,
117+
theslice: Any = slice(None),
118+
bydataset: bool = True,
112119
) -> FilesetSpec:
113120
"""
114121
Modify the input fileset so that only the chunks of each file or each dataset specified by the input slice are processed.

tests/test_dataset_tools_filespec.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -998,13 +998,6 @@ def test_invalid_form(self):
998998

999999
class TestModelFactory:
10001000
"""Test ModelFactory class methods"""
1001-
1002-
def test_valid_format(self):
1003-
"""Test valid_format method"""
1004-
assert ModelFactory.valid_format("root") is True
1005-
assert ModelFactory.valid_format("parquet") is True
1006-
assert ModelFactory.valid_format("invalid") is False
1007-
10081001
@pytest.mark.parametrize(
10091002
"input_dict",
10101003
[
@@ -1395,9 +1388,7 @@ def test_limit_steps_per_file_slicing(self):
13951388
def test_limit_steps_method_chain_slicing(self):
13961389
"""Test limit_steps with slicing"""
13971390
spec = self.get_sliceable_spec()
1398-
limited_spec = spec.limit_steps(1, per_file=True).limit_steps(
1399-
1
1400-
)
1391+
limited_spec = spec.limit_steps(1, per_file=True).limit_steps(1)
14011392
assert limited_spec.steps == {
14021393
"ZJets1": {
14031394
"tests/samples/nano_dy.root": [[0, 5]],

0 commit comments

Comments
 (0)