Skip to content

Commit 12ba3fc

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Load Croissant dataset through the args.
PiperOrigin-RevId: 646429405
1 parent 852ee9b commit 12ba3fc

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

tensorflow_datasets/scripts/cli/croissant.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,16 @@
2828

2929
import argparse
3030
import dataclasses
31+
import functools
3132
import json
3233
import typing
3334

3435
from etils import epath
36+
import mlcroissant as mlc
3537
import simple_parsing
3638
from tensorflow_datasets.core import file_adapters
3739
from tensorflow_datasets.core.dataset_builders import croissant_builder
40+
from tensorflow_datasets.core.utils import croissant_utils
3841
from tensorflow_datasets.scripts.cli import cli_utils
3942

4043

@@ -84,6 +87,25 @@ class CmdArgs(simple_parsing.helpers.FrozenSerializable):
8487
overwrite: bool = False
8588
overwrite_version: str | None = None
8689

90+
@functools.cached_property
91+
def mapping_json(self) -> dict[str, epath.PathLike]:
92+
if self.mapping:
93+
try:
94+
return json.loads(self.mapping)
95+
except json.JSONDecodeError as e:
96+
raise ValueError(
97+
f'Error parsing mapping parameter: {self.mapping}'
98+
) from e
99+
return {}
100+
101+
@functools.cached_property
102+
def dataset(self) -> mlc.Dataset:
103+
return mlc.Dataset(jsonld=self.jsonld, mapping=self.mapping_json)
104+
105+
@functools.cached_property
106+
def dataset_name(self) -> str:
107+
return croissant_utils.get_dataset_name(self.dataset)
108+
87109

88110
def register_subparser(parsers: argparse._SubParsersAction):
89111
"""Add subparser for `convert_format` command."""
@@ -109,22 +131,12 @@ def prepare_croissant_builder(args: CmdArgs) -> None:
109131
Args:
110132
args: CLI arguments.
111133
"""
112-
if args.mapping:
113-
try:
114-
mapping = json.loads(args.mapping)
115-
except json.JSONDecodeError as e:
116-
raise ValueError(
117-
f'Error parsing mapping parameter: {args.mapping}'
118-
) from e
119-
else:
120-
mapping = None
121-
122134
builder = croissant_builder.CroissantBuilder(
123135
jsonld=args.jsonld,
124136
record_set_ids=args.record_sets or None,
125137
file_format=args.file_format,
126138
data_dir=args.data_dir,
127-
mapping=mapping,
139+
mapping=args.mapping_json,
128140
overwrite_version=args.overwrite_version,
129141
)
130142
cli_utils.download_and_prepare(

0 commit comments

Comments
 (0)