Skip to content

Commit 72269b5

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Add overwrite_version to Croissant builder.
PiperOrigin-RevId: 633548755
1 parent a586b8e commit 72269b5

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from tensorflow_datasets.core.features import text_feature
5252
from tensorflow_datasets.core.utils import py_utils
5353
from tensorflow_datasets.core.utils import type_utils
54+
from tensorflow_datasets.core.utils import version as version_utils
5455
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
5556
from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd
5657

@@ -75,7 +76,7 @@ def datatype_converter(
7576
NotImplementedError
7677
"""
7778
if field.is_enumeration:
78-
raise NotImplementedError("Not implemented yet.")
79+
raise NotImplementedError('Not implemented yet.')
7980

8081
field_data_type = field.data_type
8182

@@ -95,7 +96,7 @@ def datatype_converter(
9596
elif field_data_type == mlc.DataType.IMAGE_OBJECT:
9697
return image_feature.Image(doc=field.description)
9798
else:
98-
raise ValueError(f"Unknown data type: {field_data_type}.")
99+
raise ValueError(f'Unknown data type: {field_data_type}.')
99100

100101

101102
def _extract_license(license_: Any) -> str | None:
@@ -116,19 +117,19 @@ def _extract_license(license_: Any) -> str | None:
116117
elif isinstance(license_, mlc.CreativeWork):
117118
possible_fields = [license_.name, license_.description, license_.url]
118119
fields = [field for field in possible_fields if field]
119-
return "[" + "][".join(fields) + "]"
120+
return '[' + ']['.join(fields) + ']'
120121
raise ValueError(
121-
f"license_ should be mlc.CreativeWork | str. Got {type(license_)}"
122+
f'license_ should be mlc.CreativeWork | str. Got {type(license_)}'
122123
)
123124

124125

125126
def _get_license(metadata: Any) -> str | None:
126127
"""Gets the license from the metadata."""
127128
if not isinstance(metadata, mlc.Metadata):
128-
raise ValueError(f"metadata should be mlc.Metadata. Got {type(metadata)}")
129+
raise ValueError(f'metadata should be mlc.Metadata. Got {type(metadata)}')
129130
licenses = metadata.license
130131
if licenses:
131-
return ", ".join([_extract_license(l) for l in licenses if l])
132+
return ', '.join([_extract_license(l) for l in licenses if l])
132133
return None
133134

134135

@@ -146,6 +147,7 @@ def __init__(
146147
int_dtype: type_utils.TfdsDType | None = np.int64,
147148
float_dtype: type_utils.TfdsDType | None = np.float32,
148149
mapping: Mapping[str, epath.PathLike] | None = None,
150+
overwrite_version: str | None = None,
149151
**kwargs: Any,
150152
):
151153
"""Initializes a CroissantBuilder.
@@ -164,7 +166,8 @@ def __init__(
164166
mapping: Mapping filename->filepath as a Python dict[str, str] to handle
165167
manual downloads. If `document.csv` is the FileObject and you downloaded
166168
it to `~/Downloads/document.csv`, you can specify
167-
`mapping={"document.csv": "~/Downloads/document.csv"}`.,
169+
`mapping={"document.csv": "~/Downloads/document.csv"}`.
170+
overwrite_version: Semantic version of the dataset to be set.
168171
**kwargs: kwargs to pass to GeneratorBasedBuilder directly.
169172
"""
170173
if mapping is None:
@@ -176,7 +179,9 @@ def __init__(
176179
# In TFDS, version is a mandatory attribute, while in Croissant it is only a
177180
# recommended attribute. If the version is unspecified in Croissant, we set
178181
# it to `1.0.0` in TFDS.
179-
self.VERSION = self.dataset.metadata.version or "1.0.0" # pylint: disable=invalid-name
182+
self.VERSION = version_utils.Version( # pylint: disable=invalid-name
183+
overwrite_version or self.dataset.metadata.version or '1.0.0'
184+
)
180185
self.RELEASE_NOTES = {} # pylint: disable=invalid-name
181186

182187
if not record_set_ids:
@@ -222,7 +227,7 @@ def get_record_set(self, record_set_id: str):
222227
if py_utils.make_valid_name(record_set.id) == record_set_id:
223228
return record_set
224229
raise ValueError(
225-
f"Did not find any record set with the name {record_set_id}."
230+
f'Did not find any record set with the name {record_set_id}.'
226231
)
227232

228233
def get_features(self) -> Optional[feature_lib.FeatureConnector]:
@@ -245,7 +250,7 @@ def _split_generators(
245250
) -> Dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
246251
# This will be updated when partitions are implemented in Croissant, ref to:
247252
# https://docs.google.com/document/d/1saz3usja6mk5ugJXNF64_uSXsOzIgbIV28_bu1QamVY
248-
return {"default": self._generate_examples()} # pylint: disable=unreachable
253+
return {'default': self._generate_examples()} # pylint: disable=unreachable
249254

250255
def _generate_examples(
251256
self,

tensorflow_datasets/scripts/cli/croissant.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class CmdArgs:
6060
skip_if_published: If the dataset with the same version and config is
6161
already published, then it will not be regenerated.
6262
overwrite: Delete pre-existing dataset if it exists.
63+
overwrite_version: Semantic version of the dataset to be set.
6364
"""
6465

6566
jsonld: epath.PathLike
@@ -81,6 +82,7 @@ class CmdArgs:
8182
publish_dir: epath.PathLike | None = None
8283
skip_if_published: bool = False
8384
overwrite: bool = False
85+
overwrite_version: str | None = None
8486

8587

8688
def register_subparser(parsers: argparse._SubParsersAction):
@@ -123,6 +125,7 @@ def prepare_croissant_builder(args: CmdArgs) -> None:
123125
file_format=args.file_format,
124126
data_dir=args.data_dir,
125127
mapping=mapping,
128+
overwrite_version=args.overwrite_version,
126129
)
127130
cli_utils.download_and_prepare(
128131
builder=builder,

0 commit comments

Comments
 (0)