32
32
33
33
FileFormat = file_adapters .FileFormat
34
34
35
+ DUMMY_DESCRIPTION = "Dummy description."
36
+
35
37
36
38
DUMMY_ENTRIES = [
37
39
{
51
53
]
52
54
53
55
56
+ def _create_mlc_field (
57
+ data_types : mlc .DataType | list [mlc .DataType ],
58
+ description : str ,
59
+ is_array : bool = False ,
60
+ array_shape : str | None = None ,
61
+ repeated : bool = False ,
62
+ source : mlc .Source | None = None ,
63
+ sub_fields : list [mlc .Field ] | None = None ,
64
+ ) -> mlc .Field :
65
+ field = mlc .Field (
66
+ data_types = data_types ,
67
+ description = description ,
68
+ is_array = is_array ,
69
+ array_shape = array_shape ,
70
+ repeated = repeated ,
71
+ sub_fields = sub_fields ,
72
+ )
73
+ if source is not None :
74
+ field .source = source
75
+ return field
76
+
77
+
54
78
@pytest .mark .parametrize (
55
- ["field " , "expected_feature" , "int_dtype" , "float_dtype" ],
79
+ ["mlc_field " , "expected_feature" , "int_dtype" , "float_dtype" ],
56
80
[
57
81
(
58
82
mlc .Field (
121
145
],
122
146
)
123
147
def test_simple_datatype_converter (
124
- field , expected_feature , int_dtype , float_dtype
148
+ mlc_field , expected_feature , int_dtype , float_dtype
125
149
):
126
150
actual_feature = croissant_builder .datatype_converter (
127
- field ,
151
+ mlc_field ,
128
152
int_dtype = int_dtype or np .int64 ,
129
153
float_dtype = float_dtype or np .float32 ,
130
154
)
131
155
assert actual_feature == expected_feature
132
156
133
157
134
- def test_bbox_datatype_converter ():
135
- field = mlc . Field (
158
+ def test_datatype_converter_bbox ():
159
+ field = _create_mlc_field (
136
160
data_types = mlc .DataType .BOUNDING_BOX ,
137
161
description = "Bounding box feature" ,
138
162
source = mlc .Source (format = "XYWH" ),
@@ -142,8 +166,8 @@ def test_bbox_datatype_converter():
142
166
assert actual_feature .bbox_format == bb_utils .BBoxFormat .XYWH
143
167
144
168
145
- def test_bbox_datatype_converter_with_invalid_format ():
146
- field = mlc . Field (
169
+ def test_datatype_converter_bbox_with_invalid_format ():
170
+ field = _create_mlc_field (
147
171
data_types = mlc .DataType .BOUNDING_BOX ,
148
172
description = "Bounding box feature" ,
149
173
source = mlc .Source (format = "InvalidFormat" ),
@@ -153,7 +177,7 @@ def test_bbox_datatype_converter_with_invalid_format():
153
177
154
178
155
179
@pytest .mark .parametrize (
156
- ["field " , "feature_type" , "subfield_types" ],
180
+ ["mlc_field " , "feature_type" , "subfield_types" ],
157
181
[
158
182
(
159
183
mlc .Field (data_types = mlc .DataType .TEXT , description = "Text feature" ),
@@ -219,11 +243,11 @@ def test_bbox_datatype_converter_with_invalid_format():
219
243
),
220
244
],
221
245
)
222
- def test_complex_datatype_converter ( field , feature_type , subfield_types ):
223
- actual_feature = croissant_builder .datatype_converter (field )
224
- assert actual_feature .doc .desc == field .description
246
+ def test_datatype_converter_complex ( mlc_field , feature_type , subfield_types ):
247
+ actual_feature = croissant_builder .datatype_converter (mlc_field )
248
+ assert actual_feature .doc .desc == mlc_field .description
225
249
assert isinstance (actual_feature , feature_type )
226
- if subfield_types :
250
+ if subfield_types is not None :
227
251
for feature_name in actual_feature .keys ():
228
252
assert isinstance (
229
253
actual_feature [feature_name ], subfield_types [feature_name ]
@@ -238,67 +262,134 @@ def test_datatype_converter_none():
238
262
239
263
240
264
def test_multidimensional_datatype_converter ():
241
- field = mlc . Field (
265
+ mlc_field = _create_mlc_field (
242
266
data_types = mlc .DataType .TEXT ,
243
267
description = "Text feature" ,
244
268
is_array = True ,
245
269
array_shape = "2,2" ,
246
270
)
247
- actual_feature = croissant_builder .datatype_converter (field )
271
+ actual_feature = croissant_builder .datatype_converter (mlc_field )
248
272
assert isinstance (actual_feature , tensor_feature .Tensor )
249
273
assert actual_feature .shape == (2 , 2 )
250
274
assert actual_feature .dtype == np .str_
251
275
252
276
253
277
def test_multidimensional_datatype_converter_image_object ():
254
- field = mlc . Field (
278
+ mlc_field = _create_mlc_field (
255
279
data_types = mlc .DataType .IMAGE_OBJECT ,
256
280
description = "Text feature" ,
257
281
is_array = True ,
258
282
array_shape = "2,2" ,
259
283
)
260
- actual_feature = croissant_builder .datatype_converter (field )
284
+ actual_feature = croissant_builder .datatype_converter (mlc_field )
261
285
assert isinstance (actual_feature , sequence_feature .Sequence )
262
286
assert isinstance (actual_feature .feature , sequence_feature .Sequence )
263
287
assert isinstance (actual_feature .feature .feature , image_feature .Image )
264
288
265
289
266
290
def test_multidimensional_datatype_converter_plain_list ():
267
- field = mlc . Field (
291
+ mlc_field = _create_mlc_field (
268
292
data_types = mlc .DataType .TEXT ,
269
293
description = "Text feature" ,
270
294
is_array = True ,
271
295
array_shape = "-1" ,
272
296
)
273
- actual_feature = croissant_builder .datatype_converter (field )
297
+ actual_feature = croissant_builder .datatype_converter (mlc_field )
274
298
assert isinstance (actual_feature , sequence_feature .Sequence )
275
299
assert isinstance (actual_feature .feature , text_feature .Text )
276
300
277
301
278
302
def test_multidimensional_datatype_converter_unknown_shape ():
279
- field = mlc . Field (
303
+ mlc_field = _create_mlc_field (
280
304
data_types = mlc .DataType .TEXT ,
281
305
description = "Text feature" ,
282
306
is_array = True ,
283
307
array_shape = "-1,2" ,
284
308
)
285
- actual_feature = croissant_builder .datatype_converter (field )
309
+ actual_feature = croissant_builder .datatype_converter (mlc_field )
286
310
assert isinstance (actual_feature , sequence_feature .Sequence )
287
311
assert isinstance (actual_feature .feature , sequence_feature .Sequence )
288
312
assert isinstance (actual_feature .feature .feature , text_feature .Text )
289
313
290
314
291
315
def test_sequence_feature_datatype_converter ():
292
- field = mlc . Field (
316
+ mlc_field = _create_mlc_field (
293
317
data_types = mlc .DataType .TEXT ,
294
318
description = "Text feature" ,
295
319
repeated = True ,
296
320
)
297
- actual_feature = croissant_builder .datatype_converter (field )
321
+ actual_feature = croissant_builder .datatype_converter (mlc_field )
298
322
assert isinstance (actual_feature , sequence_feature .Sequence )
299
323
assert isinstance (actual_feature .feature , text_feature .Text )
300
324
301
325
326
+ @pytest .mark .parametrize (
327
+ ["license_" , "expected_license" ],
328
+ [
329
+ ("MIT" , "MIT" ),
330
+ (
331
+ mlc .CreativeWork (
332
+ name = "Creative Commons" ,
333
+ description = "Attribution 4.0 International" ,
334
+ url = "https://creativecommons.org/licenses/by/4.0/" ,
335
+ ),
336
+ (
337
+ "[Creative Commons][Attribution 4.0"
338
+ " International][https://creativecommons.org/licenses/by/4.0/]"
339
+ ),
340
+ ),
341
+ (
342
+ mlc .CreativeWork (
343
+ name = "Creative Commons" ,
344
+ ),
345
+ "[Creative Commons]" ,
346
+ ),
347
+ (
348
+ mlc .CreativeWork (
349
+ description = "Attribution 4.0 International" ,
350
+ ),
351
+ "[Attribution 4.0 International]" ,
352
+ ),
353
+ (
354
+ mlc .CreativeWork (
355
+ url = "https://creativecommons.org/licenses/by/4.0/" ,
356
+ ),
357
+ "[https://creativecommons.org/licenses/by/4.0/]" ,
358
+ ),
359
+ (
360
+ mlc .CreativeWork (),
361
+ "[]" ,
362
+ ),
363
+ ],
364
+ )
365
+ def test_extract_license (license_ , expected_license ):
366
+ actual_license = croissant_builder ._extract_license (license_ )
367
+ assert actual_license == expected_license
368
+
369
+
370
+ def test_extract_license_with_invalid_input ():
371
+ with pytest .raises (
372
+ ValueError , match = "^license_ should be mlc.CreativeWork | str"
373
+ ):
374
+ croissant_builder ._extract_license (123 )
375
+
376
+
377
+ def test_get_license ():
378
+ metadata = mlc .Metadata (license = ["MIT" , "Apache 2.0" ])
379
+ actual_license = croissant_builder ._get_license (metadata )
380
+ assert actual_license == "MIT, Apache 2.0"
381
+
382
+
383
+ def test_get_license_with_invalid_input ():
384
+ with pytest .raises (ValueError , match = "metadata should be mlc.Metadata" ):
385
+ croissant_builder ._get_license (123 )
386
+
387
+
388
+ def test_get_license_with_empty_license ():
389
+ metadata = mlc .Metadata (license = [])
390
+ assert croissant_builder ._get_license (metadata ) is None
391
+
392
+
302
393
def test_version_converter (tmp_path ):
303
394
with testing .dummy_croissant_file (version = "1.0" ) as croissant_file :
304
395
builder = croissant_builder .CroissantBuilder (
@@ -344,7 +435,7 @@ def test_croissant_builder(crs_builder):
344
435
crs_builder ._info ().citation
345
436
== "@article{dummyarticle, title={title}, author={author}, year={2020}}"
346
437
)
347
- assert crs_builder ._info ().description == "Dummy description."
438
+ assert crs_builder ._info ().description == DUMMY_DESCRIPTION
348
439
assert crs_builder ._info ().homepage == "https://dummy_url"
349
440
assert crs_builder ._info ().redistribution_info .license == "Public"
350
441
# One `split` and one `jsonl` recordset.
0 commit comments