Skip to content

Commit 8f0f960

Browse files
fineguy authored and
The TensorFlow Datasets Authors committed
Refactor Croissant preparation
PiperOrigin-RevId: 675086192
1 parent acdccd4 commit 8f0f960

File tree

2 files changed

+27
-23
lines changed

2 files changed

+27
-23
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -197,8 +197,8 @@ def __init__(
197197
if not record_set_ids:
198198
record_set_ids = croissant_utils.get_record_set_ids(self.metadata)
199199
config_names = [
200-
conversion_utils.to_tfds_name(record_set)
201-
for record_set in record_set_ids
200+
conversion_utils.to_tfds_name(record_set_id)
201+
for record_set_id in record_set_ids
202202
]
203203
self.BUILDER_CONFIGS: Sequence[dataset_builder.BuilderConfig] = [ # pylint: disable=invalid-name
204204
dataset_builder.BuilderConfig(name=config_name)

tensorflow_datasets/scripts/cli/croissant.py

Lines changed: 25 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -127,40 +127,44 @@ def register_subparser(parsers: argparse._SubParsersAction):
127127
parsers._parser_class = orig_parser_class # pylint: disable=protected-access
128128
parser.add_arguments(CmdArgs, dest='args')
129129
parser.set_defaults(
130-
subparser_fn=lambda args: prepare_croissant_builder(args.args)
130+
subparser_fn=lambda args: prepare_croissant_builders(args.args)
131131
)
132132

133133

134-
def prepare_croissant_builder(args: CmdArgs) -> None:
135-
"""Creates a Croissant Builder and runs the preparation.
134+
def prepare_croissant_builder(
135+
args: CmdArgs, record_set_id: str
136+
) -> croissant_builder.CroissantBuilder:
137+
"""Returns prepared Croissant Builder for the given record set id.
136138
137139
Args:
138140
args: CLI arguments.
141+
record_set_id: Record set id.
139142
"""
140143
builder = croissant_builder.CroissantBuilder(
141144
jsonld=args.jsonld,
142-
record_set_ids=args.record_set_ids,
145+
record_set_ids=[record_set_id],
143146
file_format=args.file_format,
144147
data_dir=args.data_dir,
145148
mapping=args.mapping_json,
146149
overwrite_version=args.overwrite_version,
147150
)
151+
cli_utils.download_and_prepare(
152+
builder=builder,
153+
download_config=None,
154+
download_dir=args.download_dir,
155+
publish_dir=args.publish_dir,
156+
skip_if_published=args.skip_if_published,
157+
overwrite=args.overwrite,
158+
)
159+
return builder
160+
148161

162+
def prepare_croissant_builders(args: CmdArgs):
163+
"""Creates Croissant Builders and prepares them.
164+
165+
Args:
166+
args: CLI arguments.
167+
"""
149168
# Generate each config sequentially.
150-
for config in builder.BUILDER_CONFIGS:
151-
builder_for_config = croissant_builder.CroissantBuilder(
152-
jsonld=args.jsonld,
153-
record_set_ids=[config.name],
154-
file_format=args.file_format,
155-
data_dir=args.data_dir,
156-
mapping=args.mapping_json,
157-
overwrite_version=args.overwrite_version,
158-
)
159-
cli_utils.download_and_prepare(
160-
builder=builder_for_config,
161-
download_config=None,
162-
download_dir=args.download_dir,
163-
publish_dir=args.publish_dir,
164-
skip_if_published=args.skip_if_published,
165-
overwrite=args.overwrite,
166-
)
169+
for record_set_id in args.record_set_ids:
170+
prepare_croissant_builder(args=args, record_set_id=record_set_id)

0 commit comments

Comments
 (0)