40
40
import json
41
41
from typing import Any
42
42
43
+ from etils import enp
43
44
from etils import epath
44
45
import numpy as np
45
46
from tensorflow_datasets .core import dataset_builder
@@ -79,8 +80,7 @@ def _strip_record_set_prefix(
79
80
def array_datatype_converter (
80
81
feature : type_utils .TfdsDType | feature_lib .FeatureConnector | None ,
81
82
field : mlc .Field ,
82
- int_dtype : type_utils .TfdsDType = np .int64 ,
83
- float_dtype : type_utils .TfdsDType = np .float32 ,
83
+ dtype_mapping : Mapping [type_utils .TfdsDType , type_utils .TfdsDType ],
84
84
):
85
85
"""Includes the given feature in a sequence or tensor feature.
86
86
@@ -91,32 +91,28 @@ def array_datatype_converter(
91
91
Args:
92
92
feature: The inner feature to include in a sequence or tensor feature.
93
93
field: The mlc.Field object.
94
- int_dtype: The dtype to use for TFDS integer features. Defaults to np.int64.
95
- float_dtype: The dtype to use for TFDS float features. Defaults to
96
- np.float32.
94
+ dtype_mapping: A mapping of dtypes to the corresponding dtypes that will be
95
+ used in TFDS.
97
96
98
97
Returns:
99
98
A sequence or tensor feature including the inner feature.
100
99
"""
101
- dtype_mapping = {
102
- int : int_dtype ,
103
- float : float_dtype ,
104
- bool : np .bool_ ,
105
- bytes : np .str_ ,
106
- }
107
- dtype = dtype_mapping .get (field .data_type , None )
100
+ field_dtype = None
101
+ if field .data_type in dtype_mapping :
102
+ field_dtype = dtype_mapping [field .data_type ]
103
+ elif enp .lazy .is_np_dtype (field .data_type ):
104
+ field_dtype = field .data_type
105
+
108
106
if len (field .array_shape_tuple ) == 1 :
109
107
return sequence_feature .Sequence (feature , doc = field .description )
110
- elif (- 1 in field .array_shape_tuple ) or (
111
- field .data_type not in dtype_mapping
112
- ):
108
+ elif (- 1 in field .array_shape_tuple ) or (field_dtype is None ):
113
109
for _ in range (len (field .array_shape_tuple )):
114
110
feature = sequence_feature .Sequence (feature , doc = field .description )
115
111
return feature
116
112
else :
117
113
return tensor_feature .Tensor (
118
114
shape = field .array_shape_tuple ,
119
- dtype = dtype ,
115
+ dtype = field_dtype ,
120
116
doc = field .description ,
121
117
)
122
118
@@ -142,8 +138,15 @@ def datatype_converter(
142
138
"""
143
139
if field .is_enumeration :
144
140
raise NotImplementedError ('Not implemented yet.' )
141
+ dtype_mapping = {
142
+ bool : np .bool_ ,
143
+ bytes : np .str_ ,
144
+ float : float_dtype ,
145
+ int : int_dtype ,
146
+ }
145
147
146
148
field_data_type = field .data_type
149
+
147
150
if not field_data_type :
148
151
# Fields with sub fields are of type None
149
152
if field .sub_fields :
@@ -158,14 +161,12 @@ def datatype_converter(
158
161
)
159
162
else :
160
163
feature = None
161
- elif field_data_type == int :
162
- feature = int_dtype
163
- elif field_data_type == float :
164
- feature = float_dtype
165
- elif field_data_type == bool :
166
- feature = np .bool_
167
164
elif field_data_type == bytes :
168
165
feature = text_feature .Text (doc = field .description )
166
+ elif field_data_type in dtype_mapping :
167
+ feature = dtype_mapping [field_data_type ]
168
+ elif enp .lazy .is_np_dtype (field_data_type ):
169
+ feature = field_data_type
169
170
# We return a text feature for mlc.DataType.DATE features.
170
171
elif field_data_type == pd .Timestamp :
171
172
feature = text_feature .Text (doc = field .description )
@@ -183,8 +184,7 @@ def datatype_converter(
183
184
feature = array_datatype_converter (
184
185
feature = feature ,
185
186
field = field ,
186
- int_dtype = int_dtype ,
187
- float_dtype = float_dtype ,
187
+ dtype_mapping = dtype_mapping ,
188
188
)
189
189
# If the field is repeated, we return a sequence feature. `field.repeated` is
190
190
# deprecated starting from Croissant 1.1, but we still support it for
0 commit comments