52
52
from tensorflow_datasets .core .features import features_dict
53
53
from tensorflow_datasets .core .features import image_feature
54
54
from tensorflow_datasets .core .features import sequence_feature
55
+ from tensorflow_datasets .core .features import tensor_feature
55
56
from tensorflow_datasets .core .features import text_feature
56
57
from tensorflow_datasets .core .utils import conversion_utils
57
58
from tensorflow_datasets .core .utils import croissant_utils
@@ -75,6 +76,51 @@ def _strip_record_set_prefix(
75
76
}
76
77
77
78
79
def array_datatype_converter(
    feature: type_utils.TfdsDType | feature_lib.FeatureConnector | None,
    field: mlc.Field,
    int_dtype: type_utils.TfdsDType = np.int64,
    float_dtype: type_utils.TfdsDType = np.float32,
):
  """Wraps `feature` according to the array shape declared on `field`.

  The wrapping is chosen from `field.array_shape_tuple` and `field.data_type`:

  * Exactly one dimension: a single `Sequence` around `feature`.
  * Any dimension equal to -1, or a data type that is not one of `int`,
    `float`, `bool` or `bytes`: one nested `Sequence` per dimension around
    `feature`.
  * Anything else: a `Tensor` with the declared shape and the mapped dtype. In
    this case `feature` itself is not used.

  Args:
    feature: The inner feature to wrap in sequence feature(s).
    field: The mlc.Field object providing the shape, data type and description.
    int_dtype: The dtype used when the field data type is `int`. Defaults to
      np.int64.
    float_dtype: The dtype used when the field data type is `float`. Defaults
      to np.float32.

  Returns:
    A sequence feature (possibly nested) wrapping `feature`, or a tensor
    feature built from the field's shape and dtype.
  """
  # Python types that can be stored directly as a Tensor dtype.
  native_dtypes = {
      int: int_dtype,
      float: float_dtype,
      bool: np.bool_,
      bytes: np.str_,
  }
  # The data type is read before the shape, as the mapping lookup comes first.
  data_type = field.data_type
  tensor_dtype = native_dtypes.get(data_type)
  shape = field.array_shape_tuple
  rank = len(shape)

  if rank == 1:
    return sequence_feature.Sequence(feature, doc=field.description)

  if -1 not in shape and data_type in native_dtypes:
    return tensor_feature.Tensor(
        shape=shape,
        dtype=tensor_dtype,
        doc=field.description,
    )

  # Unknown dimension or non-native dtype: nest one Sequence per dimension.
  nested = feature
  for _ in range(rank):
    nested = sequence_feature.Sequence(nested, doc=field.description)
  return nested
122
+
123
+
78
124
def datatype_converter (
79
125
field : mlc .Field ,
80
126
int_dtype : type_utils .TfdsDType = np .int64 ,
@@ -133,6 +179,16 @@ def datatype_converter(
133
179
else :
134
180
raise ValueError (f'Unknown data type: { field_data_type } .' )
135
181
182
+ if feature and field .is_array :
183
+ feature = array_datatype_converter (
184
+ feature = feature ,
185
+ field = field ,
186
+ int_dtype = int_dtype ,
187
+ float_dtype = float_dtype ,
188
+ )
189
+ # If the field is repeated, we return a sequence feature. `field.repeated` is
190
+ # deprecated starting from Croissant 1.1, but we still support it for
191
+ # backwards compatibility.
136
192
if feature and field .repeated :
137
193
feature = sequence_feature .Sequence (feature , doc = field .description )
138
194
return feature
0 commit comments