@@ -637,7 +637,7 @@ def initialized(self) -> bool:
637
637
638
638
@property
def as_json(self) -> str:
  """Dataset info serialized to a JSON string (keys sorted)."""
  info_proto = self.as_proto
  return get_dataset_info_json(info_proto)
641
641
642
642
def write_to_directory (
643
643
self , dataset_info_dir : epath .PathLike , all_metadata = True
@@ -671,7 +671,7 @@ def write_to_directory(
671
671
672
672
def write_dataset_info_json(self, dataset_info_dir: epath.PathLike) -> None:
  """Writes only the dataset_info.json file to the given directory."""
  info_proto = self.as_proto
  write_dataset_info_proto(info_proto, dataset_info_dir=dataset_info_dir)
675
675
676
676
def read_from_directory (self , dataset_info_dir : epath .PathLike ) -> None :
677
677
"""Update DatasetInfo from the metadata files in `dataset_info_dir`.
@@ -852,18 +852,10 @@ def add_tfds_data_source_access(
852
852
dataset_reference:
853
853
url: a URL referring to the TFDS dataset.
854
854
"""
855
- self ._info_proto .data_source_accesses .append (
856
- dataset_info_pb2 .DataSourceAccess (
857
- access_timestamp_ms = _now_in_milliseconds (),
858
- tfds_dataset = dataset_info_pb2 .TfdsDatasetReference (
859
- name = dataset_reference .dataset_name ,
860
- config = dataset_reference .config ,
861
- version = str (dataset_reference .version ),
862
- data_dir = os .fspath (dataset_reference .data_dir ),
863
- ds_namespace = dataset_reference .namespace ,
864
- ),
865
- url = dataset_info_pb2 .Url (url = url ),
866
- )
855
+ add_tfds_data_source_access (
856
+ dataset_info_proto = self ._info_proto ,
857
+ dataset_reference = dataset_reference ,
858
+ url = url ,
867
859
)
868
860
869
861
def initialize_from_bucket (self ) -> None :
@@ -1130,6 +1122,22 @@ def get_dataset_feature_statistics(builder, split):
1130
1122
return statistics .datasets [0 ], schema
1131
1123
1132
1124
1125
def get_dataset_info_json(
    dataset_info_proto: dataset_info_pb2.DatasetInfo,
) -> str:
  """Serializes the given `DatasetInfo` proto to a JSON string.

  Keys are sorted so that the output is deterministic.

  Args:
    dataset_info_proto: the dataset info proto to serialize.

  Returns:
    The JSON representation of the proto.
  """
  return json_format.MessageToJson(dataset_info_proto, sort_keys=True)
1129
+
1130
+
1131
def write_dataset_info_proto(
    dataset_info_proto: dataset_info_pb2.DatasetInfo,
    dataset_info_dir: epath.PathLike,
) -> None:
  """Writes the dataset info proto to the given path."""
  target_dir = epath.Path(dataset_info_dir)
  serialized = get_dataset_info_json(dataset_info_proto)
  dataset_info_path(target_dir).write_text(serialized)
1139
+
1140
+
1133
1141
def read_from_json (path : epath .PathLike ) -> dataset_info_pb2 .DatasetInfo :
1134
1142
"""Read JSON-formatted proto into DatasetInfo proto.
1135
1143
@@ -1308,6 +1316,36 @@ def supports_file_format(
1308
1316
return file_format in available_file_formats (dataset_info_proto )
1309
1317
1310
1318
1319
def get_split_dict_from_proto(
    dataset_info_proto: dataset_info_pb2.DatasetInfo,
    data_dir: epath.PathLike,
    file_format: str | file_adapters.FileFormat | None = None,
) -> splits_lib.SplitDict:
  """Returns the split dict with all split infos from the given dataset.

  Args:
    dataset_info_proto: the proto with the dataset info and split infos.
    data_dir: the directory where the data is stored.
    file_format: the file format for which to get the split dict. If the file
      format is not specified, the file format from the dataset info proto is
      used.
  """
  # Fall back to the proto's recorded format when no explicit format is given.
  resolved_format = file_adapters.FileFormat(
      file_format if file_format else dataset_info_proto.file_format
  )
  template = naming.ShardedFileTemplate(
      dataset_name=dataset_info_proto.name,
      data_dir=epath.Path(data_dir),
      filetype_suffix=resolved_format.file_suffix,
  )
  return splits_lib.SplitDict.from_proto(
      repeated_split_infos=dataset_info_proto.splits,
      filename_template=template,
  )
1347
+
1348
+
1311
1349
def get_split_info_from_proto (
1312
1350
dataset_info_proto : dataset_info_pb2 .DatasetInfo ,
1313
1351
split_name : str ,
@@ -1328,22 +1366,40 @@ def get_split_info_from_proto(
1328
1366
f"File format { file_format .value } does not match available dataset file"
1329
1367
f" formats: { sorted (available_format )} ."
1330
1368
)
1331
- for split_info in dataset_info_proto .splits :
1332
- if split_info .name == split_name :
1333
- filename_template = naming .ShardedFileTemplate (
1334
- dataset_name = dataset_info_proto .name ,
1335
- data_dir = epath .Path (data_dir ),
1336
- filetype_suffix = file_format .file_suffix ,
1337
- )
1338
- # Override the default file name template if it was set.
1339
- if split_info .filepath_template :
1340
- filename_template = filename_template .replace (
1341
- template = split_info .filepath_template
1342
- )
1343
- return splits_lib .SplitInfo .from_proto (
1344
- proto = split_info , filename_template = filename_template
1369
+
1370
+ splits_dict = get_split_dict_from_proto (
1371
+ dataset_info_proto = dataset_info_proto ,
1372
+ data_dir = data_dir ,
1373
+ file_format = file_format ,
1374
+ )
1375
+ return splits_dict .get (split_name )
1376
+
1377
+
1378
def add_tfds_data_source_access(
    dataset_info_proto: dataset_info_pb2.DatasetInfo,
    dataset_reference: naming.DatasetReference,
    url: str | None = None,
) -> None:
  """Records that the given query was used to generate this dataset.

  Args:
    dataset_info_proto: the proto with the dataset info to update.
    dataset_reference: the dataset reference to record.
    url: a URL referring to the TFDS dataset.
  """
  tfds_dataset = dataset_info_pb2.TfdsDatasetReference(
      name=dataset_reference.dataset_name,
      config=dataset_reference.config,
      version=str(dataset_reference.version),
      data_dir=os.fspath(dataset_reference.data_dir),
      ds_namespace=dataset_reference.namespace,
  )
  access = dataset_info_pb2.DataSourceAccess(
      access_timestamp_ms=_now_in_milliseconds(),
      tfds_dataset=tfds_dataset,
      url=dataset_info_pb2.Url(url=url),
  )
  dataset_info_proto.data_source_accesses.append(access)
1347
1403
1348
1404
1349
1405
class MetadataDict (Metadata , dict ):
0 commit comments