Skip to content

Commit 3ce5439

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Move Croissant related functions to croissant_utils.py
PiperOrigin-RevId: 642256699
1 parent 8731daa commit 3ce5439

File tree

5 files changed

+79
-42
lines changed

5 files changed

+79
-42
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from tensorflow_datasets.core.features import image_feature
5050
from tensorflow_datasets.core.features import sequence_feature
5151
from tensorflow_datasets.core.features import text_feature
52+
from tensorflow_datasets.core.utils import croissant_utils
5253
from tensorflow_datasets.core.utils import huggingface_utils
5354
from tensorflow_datasets.core.utils import type_utils
5455
from tensorflow_datasets.core.utils import version as version_utils
@@ -173,9 +174,7 @@ def __init__(
173174
if mapping is None:
174175
mapping = {}
175176
self.dataset = mlc.Dataset(jsonld, mapping=mapping)
176-
self.name = huggingface_utils.get_tfds_name_from_croissant_dataset(
177-
self.dataset
178-
)
177+
self.name = croissant_utils.get_tfds_dataset_name(self.dataset)
179178
self.metadata = self.dataset.metadata
180179

181180
# In TFDS, version is a mandatory attribute, while in Croissant it is only a
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# coding=utf-8
2+
# Copyright 2024 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Utility functions for croissant_builder."""
17+
18+
from __future__ import annotations
19+
20+
import typing
21+
22+
from tensorflow_datasets.core.utils import huggingface_utils
23+
24+
if typing.TYPE_CHECKING:
25+
# pylint: disable=g-bad-import-order
26+
import mlcroissant as mlc
27+
28+
_HUGGINGFACE_URL_PREFIX = 'https://huggingface.co/datasets/'
29+
30+
31+
def get_dataset_name(dataset: mlc.Dataset) -> str:
32+
"""Returns dataset name of the given MLcroissant dataset."""
33+
if (url := dataset.metadata.url) and url.startswith(_HUGGINGFACE_URL_PREFIX):
34+
return url.removeprefix(_HUGGINGFACE_URL_PREFIX)
35+
return dataset.metadata.name
36+
37+
38+
def get_tfds_dataset_name(dataset: mlc.Dataset) -> str:
39+
"""Returns TFDS compatible dataset name of the given MLcroissant dataset."""
40+
dataset_name = get_dataset_name(dataset)
41+
return huggingface_utils.convert_hf_name(dataset_name)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# coding=utf-8
2+
# Copyright 2024 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import mlcroissant as mlc
17+
import pytest
18+
from tensorflow_datasets.core.utils import croissant_utils
19+
20+
21+
@pytest.mark.parametrize(
22+
'croissant_name,croissant_url,tfds_name',
23+
[
24+
(
25+
'Name+1',
26+
'https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k',
27+
'huggingfaceh4__ultrachat_200k',
28+
),
29+
('Name+1', 'bad_url', 'name_1'),
30+
('Name+1', None, 'name_1'),
31+
],
32+
)
33+
def test_get_tfds_dataset_name(croissant_name, croissant_url, tfds_name):
34+
metadata = mlc.Metadata(name=croissant_name, url=croissant_url)
35+
dataset = mlc.Dataset.from_metadata(metadata)
36+
assert croissant_utils.get_tfds_dataset_name(dataset) == tfds_name

tensorflow_datasets/core/utils/huggingface_utils.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from collections.abc import Mapping, Sequence
2121
import datetime
22-
import typing
2322
from typing import Any, Type, TypeVar
2423

2524
from etils import epath
@@ -33,10 +32,6 @@
3332
from tensorflow_datasets.core.utils.lazy_imports_utils import datasets as hf_datasets
3433
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
3534

36-
if typing.TYPE_CHECKING:
37-
# pylint: disable=g-bad-import-order
38-
import mlcroissant as mlc
39-
4035

4136
_HF_DTYPE_TO_NP_DTYPE = immutabledict.immutabledict({
4237
'bool': np.bool_,
@@ -249,16 +244,6 @@ def convert_hf_value(
249244
)
250245

251246

252-
def get_tfds_name_from_croissant_dataset(dataset: mlc.Dataset) -> str:
253-
"""Returns TFDS compatible dataset name of the given MLcroissant dataset."""
254-
if (url := dataset.metadata.url) and url.startswith(
255-
'https://huggingface.co/datasets/'
256-
):
257-
url_suffix = url.removeprefix('https://huggingface.co/datasets/')
258-
return convert_hf_name(url_suffix)
259-
return convert_hf_name(dataset.metadata.name)
260-
261-
262247
def convert_hf_name(hf_name: str) -> str:
263248
"""Converts Huggingface name to a TFDS compatible dataset name.
264249

tensorflow_datasets/core/utils/huggingface_utils_test.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from tensorflow_datasets.core import lazy_imports_lib
2323
from tensorflow_datasets.core import registered
2424
from tensorflow_datasets.core.utils import huggingface_utils
25-
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
2625
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
2726

2827

@@ -217,29 +216,6 @@ def test_convert_value(hf_value, feature, expected_value):
217216
assert huggingface_utils.convert_hf_value(hf_value, feature) == expected_value
218217

219218

220-
@pytest.mark.parametrize(
221-
'croissant_name,croissant_url,tfds_name',
222-
[
223-
(
224-
'Name+1',
225-
'https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k',
226-
'huggingfaceh4__ultrachat_200k',
227-
),
228-
('Name+1', 'bad_url', 'name_1'),
229-
('Name+1', None, 'name_1'),
230-
],
231-
)
232-
def test_get_tfds_name_from_croissant_dataset(
233-
croissant_name, croissant_url, tfds_name
234-
):
235-
metadata = mlc.Metadata(name=croissant_name, url=croissant_url)
236-
dataset = mlc.Dataset.from_metadata(metadata)
237-
assert (
238-
huggingface_utils.get_tfds_name_from_croissant_dataset(dataset)
239-
== tfds_name
240-
)
241-
242-
243219
@pytest.mark.parametrize(
244220
'hf_name,tfds_name',
245221
[

0 commit comments

Comments
 (0)