Skip to content

Commit 61c3c06

Browse files
author
The TensorFlow Datasets Authors
committed
Dealing with Splits in CroissantBuilder
PiperOrigin-RevId: 653926362
1 parent 3360187 commit 61c3c06

File tree

2 files changed: 35 additions (+35) and 1 deletion (-1)

tensorflow_datasets/core/utils/croissant_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ def get_split_recordset(
108108
if not record_sets:
109109
raise ValueError("field {field.id} has no RecordSet")
110110
referenced_record_set = record_sets[0]
111-
if str(mlc.DataType.SPLIT) in referenced_record_set.data_types:
111+
if mlc.DataType.SPLIT in referenced_record_set.data_types:
112112
return SplitReference(referenced_record_set, field)
113113
return None
114114

tensorflow_datasets/core/utils/croissant_utils_test.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,40 @@ def test_get_record_set_ids():
5757
assert record_set_ids == ['record_set_1']
5858

5959

60+
def test_get_split_recordset():
61+
record_sets = [
62+
mlc.RecordSet(
63+
id='records',
64+
fields=[
65+
mlc.Field(
66+
id='records/split',
67+
data_types=[mlc.DataType.TEXT],
68+
references=mlc.Source(field='splits/name'),
69+
)
70+
],
71+
),
72+
mlc.RecordSet(
73+
id='splits',
74+
key='name',
75+
data_types=[mlc.DataType.SPLIT],
76+
fields=[
77+
mlc.Field(
78+
id='splits/name', name='name', data_types=mlc.DataType.TEXT
79+
)
80+
],
81+
data=[{'name': 'train'}, {'name': 'test'}],
82+
),
83+
]
84+
metadata = mlc.Metadata(name='dummy', url='dum.my', record_sets=record_sets)
85+
dataset = mlc.Dataset.from_metadata(metadata)
86+
split_recordset = croissant_utils.get_split_recordset(
87+
record_set=dataset.metadata.record_sets[0], metadata=metadata
88+
)
89+
assert split_recordset is not None
90+
assert split_recordset.reference_field.id == 'records/split'
91+
assert split_recordset.split_record_set.id == 'splits'
92+
93+
6094
def test_get_split_recordset_with_no_split_recordset():
6195
record_sets = [
6296
mlc.RecordSet(

0 commit comments

Comments (0)