51
51
from tensorflow_datasets .core .features import text_feature
52
52
from tensorflow_datasets .core .utils import py_utils
53
53
from tensorflow_datasets .core .utils import type_utils
54
+ from tensorflow_datasets .core .utils import version as version_utils
54
55
from tensorflow_datasets .core .utils .lazy_imports_utils import mlcroissant as mlc
55
56
from tensorflow_datasets .core .utils .lazy_imports_utils import pandas as pd
56
57
@@ -75,7 +76,7 @@ def datatype_converter(
75
76
NotImplementedError
76
77
"""
77
78
if field .is_enumeration :
78
- raise NotImplementedError (" Not implemented yet." )
79
+ raise NotImplementedError (' Not implemented yet.' )
79
80
80
81
field_data_type = field .data_type
81
82
@@ -95,7 +96,7 @@ def datatype_converter(
95
96
elif field_data_type == mlc .DataType .IMAGE_OBJECT :
96
97
return image_feature .Image (doc = field .description )
97
98
else :
98
- raise ValueError (f" Unknown data type: { field_data_type } ." )
99
+ raise ValueError (f' Unknown data type: { field_data_type } .' )
99
100
100
101
101
102
def _extract_license (license_ : Any ) -> str | None :
@@ -116,19 +117,19 @@ def _extract_license(license_: Any) -> str | None:
116
117
elif isinstance (license_ , mlc .CreativeWork ):
117
118
possible_fields = [license_ .name , license_ .description , license_ .url ]
118
119
fields = [field for field in possible_fields if field ]
119
- return "[" + "][" .join (fields ) + "]"
120
+ return '[' + '][' .join (fields ) + ']'
120
121
raise ValueError (
121
- f" license_ should be mlc.CreativeWork | str. Got { type (license_ )} "
122
+ f' license_ should be mlc.CreativeWork | str. Got { type (license_ )} '
122
123
)
123
124
124
125
125
126
def _get_license (metadata : Any ) -> str | None :
126
127
"""Gets the license from the metadata."""
127
128
if not isinstance (metadata , mlc .Metadata ):
128
- raise ValueError (f" metadata should be mlc.Metadata. Got { type (metadata )} " )
129
+ raise ValueError (f' metadata should be mlc.Metadata. Got { type (metadata )} ' )
129
130
licenses = metadata .license
130
131
if licenses :
131
- return ", " .join ([_extract_license (l ) for l in licenses if l ])
132
+ return ', ' .join ([_extract_license (l ) for l in licenses if l ])
132
133
return None
133
134
134
135
@@ -146,6 +147,7 @@ def __init__(
146
147
int_dtype : type_utils .TfdsDType | None = np .int64 ,
147
148
float_dtype : type_utils .TfdsDType | None = np .float32 ,
148
149
mapping : Mapping [str , epath .PathLike ] | None = None ,
150
+ overwrite_version : str | None = None ,
149
151
** kwargs : Any ,
150
152
):
151
153
"""Initializes a CroissantBuilder.
@@ -164,7 +166,8 @@ def __init__(
164
166
mapping: Mapping filename->filepath as a Python dict[str, str] to handle
165
167
manual downloads. If `document.csv` is the FileObject and you downloaded
166
168
it to `~/Downloads/document.csv`, you can specify
167
- `mapping={"document.csv": "~/Downloads/document.csv"}`.,
169
+ `mapping={"document.csv": "~/Downloads/document.csv"}`.
170
+ overwrite_version: Semantic version of the dataset to be set.
168
171
**kwargs: kwargs to pass to GeneratorBasedBuilder directly.
169
172
"""
170
173
if mapping is None :
@@ -176,7 +179,9 @@ def __init__(
176
179
# In TFDS, version is a mandatory attribute, while in Croissant it is only a
177
180
# recommended attribute. If the version is unspecified in Croissant, we set
178
181
# it to `1.0.0` in TFDS.
179
- self .VERSION = self .dataset .metadata .version or "1.0.0" # pylint: disable=invalid-name
182
+ self .VERSION = version_utils .Version ( # pylint: disable=invalid-name
183
+ overwrite_version or self .dataset .metadata .version or '1.0.0'
184
+ )
180
185
self .RELEASE_NOTES = {} # pylint: disable=invalid-name
181
186
182
187
if not record_set_ids :
@@ -222,7 +227,7 @@ def get_record_set(self, record_set_id: str):
222
227
if py_utils .make_valid_name (record_set .id ) == record_set_id :
223
228
return record_set
224
229
raise ValueError (
225
- f" Did not find any record set with the name { record_set_id } ."
230
+ f' Did not find any record set with the name { record_set_id } .'
226
231
)
227
232
228
233
def get_features (self ) -> Optional [feature_lib .FeatureConnector ]:
@@ -245,7 +250,7 @@ def _split_generators(
245
250
) -> Dict [splits_lib .Split , split_builder_lib .SplitGenerator ]:
246
251
# This will be updated when partitions are implemented in Croissant, ref to:
247
252
# https://docs.google.com/document/d/1saz3usja6mk5ugJXNF64_uSXsOzIgbIV28_bu1QamVY
248
- return {" default" : self ._generate_examples ()} # pylint: disable=unreachable
253
+ return {' default' : self ._generate_examples ()} # pylint: disable=unreachable
249
254
250
255
def _generate_examples (
251
256
self ,
0 commit comments