2 files changed: +35 -1 lines changed
tensorflow_datasets/core/utils Expand file tree Collapse file tree 2 files changed +35
-1
lines changed Original file line number Diff line number Diff line change @@ -108,7 +108,7 @@ def get_split_recordset(
108
108
if not record_sets :
109
109
raise ValueError ("field {field.id} has no RecordSet" )
110
110
referenced_record_set = record_sets [0 ]
111
- if str ( mlc .DataType .SPLIT ) in referenced_record_set .data_types :
111
+ if mlc .DataType .SPLIT in referenced_record_set .data_types :
112
112
return SplitReference (referenced_record_set , field )
113
113
return None
114
114
Original file line number Diff line number Diff line change @@ -57,6 +57,40 @@ def test_get_record_set_ids():
57
57
assert record_set_ids == ['record_set_1' ]
58
58
59
59
60
def test_get_split_recordset():
  """Checks that a field referencing a SPLIT record set is resolved.

  Two record sets are built: `records`, whose single field `records/split`
  references `splits/name`, and `splits`, which is typed
  `mlc.DataType.SPLIT`. Looking up the split record set for `records` must
  return a reference whose field is `records/split` and whose split record
  set is `splits`.
  """
  # Field on the data record set that points at the split definition.
  split_field = mlc.Field(
      id='records/split',
      data_types=[mlc.DataType.TEXT],
      references=mlc.Source(field='splits/name'),
  )
  records = mlc.RecordSet(id='records', fields=[split_field])

  # Record set typed as SPLIT, enumerating the available split names inline.
  name_field = mlc.Field(
      id='splits/name', name='name', data_types=mlc.DataType.TEXT
  )
  splits = mlc.RecordSet(
      id='splits',
      key='name',
      data_types=[mlc.DataType.SPLIT],
      fields=[name_field],
      data=[{'name': 'train'}, {'name': 'test'}],
  )

  metadata = mlc.Metadata(
      name='dummy', url='dum.my', record_sets=[records, splits]
  )
  dataset = mlc.Dataset.from_metadata(metadata)
  first_record_set = dataset.metadata.record_sets[0]

  result = croissant_utils.get_split_recordset(
      record_set=first_record_set, metadata=metadata
  )

  assert result is not None
  assert result.reference_field.id == 'records/split'
  assert result.split_record_set.id == 'splits'
92
+
93
+
60
94
def test_get_split_recordset_with_no_split_recordset ():
61
95
record_sets = [
62
96
mlc .RecordSet (
You can’t perform that action at this time.
0 commit comments