@@ -723,6 +723,7 @@ def dummy_croissant_file(
723723 entries : Sequence [dict [str , Any ]] | None = None ,
724724 raw_data_filename : epath .PathLike = 'raw_data.jsonl' ,
725725 croissant_filename : epath .PathLike = 'croissant.json' ,
726+ split_names : Sequence [str ] | None = None ,
726727) -> Iterator [epath .Path ]:
727728 """Yields temporary path to a dummy Croissant file.
728729
@@ -732,13 +733,29 @@ def dummy_croissant_file(
732733 Args:
733734 dataset_name: The name of the dataset.
734735 entries: A list of dictionaries representing the dataset's entries. Each
735- dictionary should contain an 'index' and a 'text' key. If None, the
736- function will create two entries with indices 0 and 1 and dummy text.
737- raw_data_filename: Filename of the raw data file.
736+ dictionary should contain an 'index', a 'text', and a `split` key. If
737+ None, the function will create two entries with indices 0 and 1 and dummy
738+ text, and with the first entry belonging to the split `train` and the
739+ second to `test`.
740+ raw_data_filename: Filename of the raw data file. If `split_names` is True,
741+ the function will create a raw data file for each split, including the
742+ split name before the file extension.
738743 croissant_filename: Filename of the Croissant JSON-LD file.
744+ split_names: A list of split names to populate the split record set with. If
745+ split_names are defined, they must match the `split` key in the entries.
746+ If None, the function will create a split record set with the default
747+ split names `train` and `test`. If `split_names` is defined, the `split`
748+ key in the entries must match one of the split names.
739749 """
740750 if entries is None :
741- entries = [{'index' : i , 'text' : f'Dummy example { i } ' } for i in range (2 )]
751+ entries = [
752+ {
753+ 'index' : i ,
754+ 'text' : f'Dummy example { i } ' ,
755+ 'split' : 'train' if i % 2 == 0 else 'test' ,
756+ }
757+ for i in range (2 )
758+ ]
742759
743760 fields = [
744761 mlc .Field (
@@ -771,29 +788,82 @@ def dummy_croissant_file(
771788 fields = fields ,
772789 )
773790 ]
791+ if split_names :
792+ record_sets [0 ].fields .append (
793+ mlc .Field (
794+ id = 'jsonl/split' ,
795+ name = 'jsonl/split' ,
796+ description = 'The dummy split.' ,
797+ data_types = mlc .DataType .TEXT ,
798+ source = mlc .Source (
799+ file_object = 'raw_data' ,
800+ extract = mlc .Extract (file_property = 'fullpath' ),
801+ transforms = [mlc .Transform (regex = '.*(.+).+jsonl$' )],
802+ ),
803+ references = mlc .Source (field = 'split/name' ),
804+ ),
805+ )
806+ record_sets .append (
807+ mlc .RecordSet (
808+ id = 'split' ,
809+ name = 'split' ,
810+ key = 'split/name' ,
811+ data_types = [mlc .DataType .SPLIT ],
812+ description = 'Dummy split.' ,
813+ fields = [
814+ mlc .Field (
815+ id = 'split/name' ,
816+ name = 'split/name' ,
817+ description = 'The dummy split name.' ,
818+ data_types = mlc .DataType .TEXT ,
819+ )
820+ ],
821+ data = [{'split/name' : split_name } for split_name in split_names ],
822+ )
823+ )
774824
775825 with tempfile .TemporaryDirectory () as tempdir :
776826 tempdir = epath .Path (tempdir )
777827
778828 # Write raw examples to tempdir/data.
779829 raw_data_dir = tempdir / 'data'
780830 raw_data_dir .mkdir ()
781- raw_data_file = raw_data_dir / raw_data_filename
782- raw_data_file .write_text ('\n ' .join (map (json .dumps , entries )))
783-
784- # Get the actual raw file's hash, set distribution and metadata.
785- raw_data_file_content = raw_data_file .read_text ()
786- sha256 = hashlib .sha256 (raw_data_file_content .encode ()).hexdigest ()
787- distribution = [
788- mlc .FileObject (
789- id = 'raw_data' ,
790- name = 'raw_data' ,
791- description = 'File with the data.' ,
792- encoding_format = 'application/jsonlines' ,
793- content_url = f'data/{ raw_data_filename } ' ,
794- sha256 = sha256 ,
795- ),
796- ]
831+ if split_names :
832+ parts = str (raw_data_filename ).split ('.' )
833+ file_name , extension = '.' .join (parts [:- 1 ]), parts [- 1 ]
834+ for split_name in split_names :
835+ raw_data_file = raw_data_dir / (
836+ file_name + '_' + split_name + '.' + extension
837+ )
838+ split_entries = [
839+ entry for entry in entries if entry ['split' ] == split_name
840+ ]
841+ raw_data_file .write_text ('\n ' .join (map (json .dumps , split_entries )))
842+ distribution = [
843+ mlc .FileSet (
844+ id = 'raw_data' ,
845+ name = 'raw_data' ,
846+ description = 'Files with the data.' ,
847+ encoding_format = 'application/jsonlines' ,
848+ includes = f'data/{ file_name } *.{ extension } ' ,
849+ ),
850+ ]
851+ else :
852+ raw_data_file = raw_data_dir / raw_data_filename
853+ raw_data_file .write_text ('\n ' .join (map (json .dumps , entries )))
854+ # Get the actual raw file's hash, set distribution and metadata.
855+ raw_data_file_content = raw_data_file .read_text ()
856+ sha256 = hashlib .sha256 (raw_data_file_content .encode ()).hexdigest ()
857+ distribution = [
858+ mlc .FileObject (
859+ id = 'raw_data' ,
860+ name = 'raw_data' ,
861+ description = 'File with the data.' ,
862+ encoding_format = 'application/jsonlines' ,
863+ content_url = f'data/{ raw_data_filename } ' ,
864+ sha256 = sha256 ,
865+ ),
866+ ]
797867 dummy_metadata = mlc .Metadata (
798868 name = dataset_name ,
799869 description = 'Dummy description.' ,
@@ -807,7 +877,6 @@ def dummy_croissant_file(
807877 version = '1.2.0' ,
808878 license = 'Public' ,
809879 )
810-
811880 # Write Croissant JSON-LD to tempdir.
812881 croissant_file = tempdir / croissant_filename
813882 croissant_file .write_text (json .dumps (dummy_metadata .to_json (), indent = 2 ))
0 commit comments