From aedf90cbb560918151292348e53ff729b08bee21 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Sat, 5 Apr 2025 11:29:01 +0100 Subject: [PATCH] Remove explicit schema data --- CHANGELOG.md | 7 ++++ bio2zarr/icf.py | 32 +++++++-------- bio2zarr/plink.py | 27 ++++++++---- bio2zarr/vcz.py | 102 ++++++++++++++++++++++++++++++++++------------ tests/test_icf.py | 47 ++++++++++----------- tests/test_vcz.py | 62 ---------------------------- 6 files changed, 137 insertions(+), 140 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20fd17f7..e3d08519 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 0.1.6 2025-0X-XX + +Breaking changes + +- Remove explicit sample, contig and filter lists from the schema. + Existing ICFs will need to be recreated. (#343) + # 0.1.5 2025-03-31 - Add support for merging contig IDs across multiple VCFs (#335) diff --git a/bio2zarr/icf.py b/bio2zarr/icf.py index 370432e2..8dc25c59 100644 --- a/bio2zarr/icf.py +++ b/bio2zarr/icf.py @@ -877,9 +877,9 @@ def convert_local_allele_field_types(fields): return [*fields, la] -class IntermediateColumnarFormat(collections.abc.Mapping): +class IntermediateColumnarFormat(vcz.Source): def __init__(self, path): - self.path = pathlib.Path(path) + self._path = pathlib.Path(path) # TODO raise a more informative error here telling people this # directory is either a WIP or the wrong format. 
with open(self.path / "metadata.json") as f: @@ -902,20 +902,11 @@ def __init__(self, path): def __repr__(self): return ( - f"IntermediateColumnarFormat(fields={len(self)}, " + f"IntermediateColumnarFormat(fields={len(self.fields)}, " f"partitions={self.num_partitions}, " f"records={self.num_records}, path={self.path})" ) - def __getitem__(self, key): - return self.fields[key] - - def __iter__(self): - return iter(self.fields) - - def __len__(self): - return len(self.fields) - def summary_table(self): data = [] for name, icf_field in self.fields.items(): @@ -934,6 +925,10 @@ def summary_table(self): data.append(d) return data + @property + def path(self): + return self._path + @property def num_records(self): return self.metadata.num_records @@ -944,7 +939,15 @@ def num_partitions(self): @property def samples(self): - return [sample.id for sample in self.metadata.samples] + return self.metadata.samples + + @property + def contigs(self): + return self.metadata.contigs + + @property + def filters(self): + return self.metadata.filters @property def num_samples(self): @@ -1037,9 +1040,6 @@ def generate_schema( samples_chunk_size=samples_chunk_size, variants_chunk_size=variants_chunk_size, fields=[], - samples=self.metadata.samples, - contigs=self.metadata.contigs, - filters=self.metadata.filters, ) logger.info( diff --git a/bio2zarr/plink.py b/bio2zarr/plink.py index 6789de92..17b78513 100644 --- a/bio2zarr/plink.py +++ b/bio2zarr/plink.py @@ -10,14 +10,26 @@ logger = logging.getLogger(__name__) -class PlinkFormat: +class PlinkFormat(vcz.Source): def __init__(self, path): - self.path = path + self._path = pathlib.Path(path) self.bed = bed_reader.open_bed(path, num_threads=1, count_A1=False) - self.num_records = self.bed.sid_count - self.samples = list(self.bed.iid) - self.num_samples = len(self.samples) - self.root_attrs = {} + + @property + def path(self): + return self._path + + @property + def num_records(self): + return self.bed.sid_count + + @property + def 
samples(self): + return [vcz.Sample(id=sample) for sample in self.bed.iid] + + @property + def num_samples(self): + return len(self.samples) def iter_alleles(self, start, stop, num_alleles): ref_field = self.bed.allele_1 @@ -62,9 +74,6 @@ def generate_schema( samples_chunk_size=samples_chunk_size, variants_chunk_size=variants_chunk_size, fields=[], - samples=[vcz.Sample(id=sample) for sample in self.bed.iid], - contigs=[], - filters=[], ) logger.info( diff --git a/bio2zarr/vcz.py b/bio2zarr/vcz.py index a3c306ca..12dbd2ff 100644 --- a/bio2zarr/vcz.py +++ b/bio2zarr/vcz.py @@ -1,3 +1,4 @@ +import abc import dataclasses import json import logging @@ -13,7 +14,7 @@ logger = logging.getLogger(__name__) -ZARR_SCHEMA_FORMAT_VERSION = "0.4" +ZARR_SCHEMA_FORMAT_VERSION = "0.5" DEFAULT_ZARR_COMPRESSOR = numcodecs.Blosc(cname="zstd", clevel=7) _fixed_field_descriptions = { @@ -28,6 +29,62 @@ } +class Source(abc.ABC): + @property + @abc.abstractmethod + def path(self): + pass + + @property + @abc.abstractmethod + def num_records(self): + pass + + @property + @abc.abstractmethod + def num_samples(self): + pass + + @property + @abc.abstractmethod + def samples(self): + pass + + @property + def contigs(self): + return None + + @property + def filters(self): + return None + + @property + def root_attrs(self): + return {} + + @abc.abstractmethod + def iter_alleles(self, start, stop, num_alleles): + pass + + @abc.abstractmethod + def iter_genotypes(self, start, stop, num_alleles): + pass + + def iter_id(self, start, stop): + return + + def iter_contig(self, start, stop): + return + + @abc.abstractmethod + def iter_field(self, field_name, shape, start, stop): + pass + + @abc.abstractmethod + def generate_schema(self, variants_chunk_size, samples_chunk_size, local_alleles): + pass + + @dataclasses.dataclass class ZarrArraySpec: name: str @@ -182,25 +239,16 @@ class VcfZarrSchema(core.JsonDataclass): format_version: str samples_chunk_size: int variants_chunk_size: int - samples: list 
- contigs: list - filters: list fields: list def __init__( self, format_version: str, - samples: list, - contigs: list, - filters: list, fields: list, variants_chunk_size: int = None, samples_chunk_size: int = None, ): self.format_version = format_version - self.samples = samples - self.contigs = contigs - self.filters = filters self.fields = fields if variants_chunk_size is None: variants_chunk_size = 1000 @@ -238,9 +286,6 @@ def fromdict(d): f"{d['format_version']} != {ZARR_SCHEMA_FORMAT_VERSION}" ) ret = VcfZarrSchema(**d) - ret.samples = [Sample(**sd) for sd in d["samples"]] - ret.contigs = [Contig(**sd) for sd in d["contigs"]] - ret.filters = [Filter(**sd) for sd in d["filters"]] ret.fields = [ZarrArraySpec(**sd) for sd in d["fields"]] return ret @@ -474,8 +519,10 @@ def init( # Doing this synchronously - this is fine surely self.encode_samples(root) - self.encode_filter_id(root) - self.encode_contig_id(root) + if self.source.filters is not None: + self.encode_filter_id(root) + if self.source.contigs is not None: + self.encode_contigs(root) self.wip_path.mkdir() self.arrays_path.mkdir() @@ -502,12 +549,11 @@ def init( ) def encode_samples(self, root): - if [s.id for s in self.schema.samples] != self.source.samples: - raise ValueError("Subsetting or reordering samples not supported currently") + samples = self.source.samples array = root.array( "sample_id", - data=[sample.id for sample in self.schema.samples], - shape=len(self.schema.samples), + data=[sample.id for sample in samples], + shape=len(samples), dtype="str", compressor=DEFAULT_ZARR_COMPRESSOR, chunks=(self.schema.samples_chunk_size,), @@ -515,20 +561,21 @@ def encode_samples(self, root): array.attrs["_ARRAY_DIMENSIONS"] = ["samples"] logger.debug("Samples done") - def encode_contig_id(self, root): + def encode_contigs(self, root): + contigs = self.source.contigs array = root.array( "contig_id", - data=[contig.id for contig in self.schema.contigs], - shape=len(self.schema.contigs), + data=[contig.id 
for contig in contigs], + shape=len(contigs), dtype="str", compressor=DEFAULT_ZARR_COMPRESSOR, ) array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"] - if all(contig.length is not None for contig in self.schema.contigs): + if all(contig.length is not None for contig in contigs): array = root.array( "contig_length", - data=[contig.length for contig in self.schema.contigs], - shape=len(self.schema.contigs), + data=[contig.length for contig in contigs], + shape=len(contigs), dtype=np.int64, compressor=DEFAULT_ZARR_COMPRESSOR, ) @@ -537,10 +584,11 @@ def encode_contig_id(self, root): def encode_filter_id(self, root): # TODO need a way to store description also # https://github.com/sgkit-dev/vcf-zarr-spec/issues/19 + filters = self.source.filters array = root.array( "filter_id", - data=[filt.id for filt in self.schema.filters], - shape=len(self.schema.filters), + data=[filt.id for filt in filters], + shape=len(filters), dtype="str", compressor=DEFAULT_ZARR_COMPRESSOR, ) diff --git a/tests/test_icf.py b/tests/test_icf.py index 3297b3f6..2e9275d8 100644 --- a/tests/test_icf.py +++ b/tests/test_icf.py @@ -52,11 +52,6 @@ def test_summary_table(self, icf): def test_inspect(self, icf): assert icf.summary_table() == icf_mod.inspect(icf.path) - def test_mapping_methods(self, icf): - assert len(icf) == len(self.fields) - assert icf["ALT"] is icf.fields["ALT"] - assert list(iter(icf)) == list(iter(icf)) - def test_num_partitions(self, icf): assert icf.num_partitions == 3 @@ -65,13 +60,13 @@ def test_num_records(self, icf): def test_POS(self, icf): nt.assert_array_equal( - [v[0] for v in icf["POS"].values], + [v[0] for v in icf.fields["POS"].values], [111, 112, 14370, 17330, 1110696, 1230237, 1234567, 1235237, 10], ) def test_REF(self, icf): ref = ["A", "A", "G", "T", "A", "T", "G", "T", "AC"] - assert icf["REF"].values == ref + assert icf.fields["REF"].values == ref def test_ALT(self, icf): alt = [ @@ -85,10 +80,10 @@ def test_ALT(self, icf): [], ["A", "ATG", "C"], ] - assert [list(v) 
for v in icf["ALT"].values] == alt + assert [list(v) for v in icf.fields["ALT"].values] == alt def test_INFO_NS(self, icf): - assert icf["INFO/NS"].values == [None, None, 3, 3, 2, 3, 3, None, None] + assert icf.fields["INFO/NS"].values == [None, None, 3, 3, 2, 3, 3, None, None] class TestWithGtHeaderNoGenotypes: @@ -100,7 +95,7 @@ def icf(self, tmp_path_factory): return icf_mod.explode(out, [self.data_path]) def test_gts(self, icf): - values = icf["FORMAT/GT"].values + values = icf.fields["FORMAT/GT"].values assert values == [None] * icf.num_records @@ -248,28 +243,28 @@ def test_info_schemas(self, schema, name, dtype, shape, dimensions): assert v.dimensions == dimensions def test_info_string1(self, icf): - non_missing = [v for v in icf["INFO/IS1"].values if v is not None] + non_missing = [v for v in icf.fields["INFO/IS1"].values if v is not None] assert non_missing[0] == "bc" assert non_missing[1] == "." def test_info_char1(self, icf): - non_missing = [v for v in icf["INFO/IC1"].values if v is not None] + non_missing = [v for v in icf.fields["INFO/IC1"].values if v is not None] assert non_missing[0] == "f" assert non_missing[1] == "." 
def test_info_string2(self, icf): - non_missing = [v for v in icf["INFO/IS2"].values if v is not None] + non_missing = [v for v in icf.fields["INFO/IS2"].values if v is not None] nt.assert_array_equal(non_missing[0], ["hij", "d"]) nt.assert_array_equal(non_missing[1], [".", "d"]) nt.assert_array_equal(non_missing[2], ["hij", "."]) nt.assert_array_equal(non_missing[3], [".", "."]) def test_format_string1(self, icf): - non_missing = [v for v in icf["FORMAT/FS1"].values if v is not None] + non_missing = [v for v in icf.fields["FORMAT/FS1"].values if v is not None] nt.assert_array_equal(non_missing[0], [["bc"], ["."]]) def test_format_string2(self, icf): - non_missing = [v for v in icf["FORMAT/FS2"].values if v is not None] + non_missing = [v for v in icf.fields["FORMAT/FS2"].values if v is not None] nt.assert_array_equal(non_missing[0], [["bc", "op"], [".", "op"]]) nt.assert_array_equal(non_missing[1], [["bc", "."], [".", "."]]) @@ -344,7 +339,7 @@ def test_missing_field(self, tmp_path): shutil.rmtree(icf_path / "POS") icf = icf_mod.IntermediateColumnarFormat(icf_path) with pytest.raises(FileNotFoundError): - icf["POS"].values # noqa B018 + icf.fields["POS"].values # noqa B018 def test_missing_chunk_index(self, tmp_path): icf_path = tmp_path / "icf" @@ -354,7 +349,7 @@ def test_missing_chunk_index(self, tmp_path): chunk_index_path.unlink() icf = icf_mod.IntermediateColumnarFormat(icf_path) with pytest.raises(FileNotFoundError): - icf["POS"].values # noqa B018 + icf.fields["POS"].values # noqa B018 def test_missing_chunk_file(self, tmp_path): icf_path = tmp_path / "icf" @@ -364,7 +359,7 @@ def test_missing_chunk_file(self, tmp_path): chunk_file.unlink() icf = icf_mod.IntermediateColumnarFormat(icf_path) with pytest.raises(FileNotFoundError): - icf["POS"].values # noqa B018 + icf.fields["POS"].values # noqa B018 def test_empty_chunk_file(self, tmp_path): icf_path = tmp_path / "icf" @@ -375,7 +370,7 @@ def test_empty_chunk_file(self, tmp_path): pass icf = 
icf_mod.IntermediateColumnarFormat(icf_path) with pytest.raises(RuntimeError, match="blosc"): - icf["POS"].values # noqa B018 + icf.fields["POS"].values # noqa B018 # Chunk file is 187 long @pytest.mark.parametrize("length", [10, 100, 185]) @@ -391,7 +386,7 @@ def test_truncated_chunk_file(self, tmp_path, length): icf = icf_mod.IntermediateColumnarFormat(icf_path) # Either Blosc or pickling errors happen here with pytest.raises((RuntimeError, pickle.UnpicklingError)): - icf["POS"].values # noqa B018 + icf.fields["POS"].values # noqa B018 def test_chunk_incorrect_length(self, tmp_path): icf_path = tmp_path / "icf" @@ -408,9 +403,9 @@ def test_chunk_incorrect_length(self, tmp_path): f.write(compressor.encode(pkl)) icf = icf_mod.IntermediateColumnarFormat(icf_path) with pytest.raises(ValueError, match="Corruption detected"): - icf["POS"].values # noqa B018 + icf.fields["POS"].values # noqa B018 with pytest.raises(ValueError, match="Corruption detected"): - list(icf["POS"].iter_values(0, 9)) + list(icf.fields["POS"].iter_values(0, 9)) class TestSlicing: @@ -429,7 +424,7 @@ def test_repr(self, icf): ) def test_pos_repr(self, icf): - assert repr(icf["POS"]).startswith( + assert repr(icf.fields["POS"]).startswith( "IntermediateColumnarFormatField(name=POS, " "partition_chunks=[8, 8, 8, 8, 8], path=" ) @@ -440,14 +435,14 @@ def test_partition_record_index(self, icf): ) def test_pos_values(self, icf): - field = icf["POS"] + field = icf.fields["POS"] pos = np.array([v[0] for v in field.values]) # Check the actual values here to make sure other tests make sense actual = np.hstack([1 + np.arange(933) for _ in range(5)]) nt.assert_array_equal(pos, actual) def test_pos_chunk_records(self, icf): - pos = icf["POS"] + pos = icf.fields["POS"] for j in range(pos.num_partitions): a = pos.chunk_record_index(j) nt.assert_array_equal(a, [0, 118, 236, 354, 472, 590, 708, 826, 933]) @@ -478,7 +473,7 @@ def test_pos_chunk_records(self, icf): ], ) def test_slice(self, icf, start, stop): - 
field = icf["POS"] + field = icf.fields["POS"] pos = np.array(field.values) pos_slice = np.array(list(field.iter_values(start, stop))) nt.assert_array_equal(pos[start:stop], pos_slice) diff --git a/tests/test_vcz.py b/tests/test_vcz.py index aafbcde7..49f863a7 100644 --- a/tests/test_vcz.py +++ b/tests/test_vcz.py @@ -205,12 +205,6 @@ def test_generated_no_fields(self, icf_path): schema.fields.clear() self.assert_json_round_trip(schema) - def test_generated_no_samples(self, icf_path): - icf = icf_mod.IntermediateColumnarFormat(icf_path) - schema = icf.generate_schema() - schema.samples.clear() - self.assert_json_round_trip(schema) - def test_generated_change_dtype(self, icf_path): icf = icf_mod.IntermediateColumnarFormat(icf_path) schema = icf.generate_schema() @@ -337,23 +331,6 @@ def test_chunk_size(self, schema): assert schema.samples_chunk_size == 10000 assert schema.variants_chunk_size == 1000 - def test_samples(self, schema): - assert schema.asdict()["samples"] == [ - {"id": s} for s in ["NA00001", "NA00002", "NA00003"] - ] - - def test_contigs(self, schema): - assert schema.asdict()["contigs"] == [ - {"id": s, "length": None} for s in ["19", "20", "X"] - ] - - def test_filters(self, schema): - assert schema.asdict()["filters"] == [ - {"id": "PASS", "description": "All filters passed"}, - {"id": "s50", "description": "Less than 50% of samples have data"}, - {"id": "q10", "description": "Quality below 10"}, - ] - def test_variant_contig(self, schema): assert get_field_dict(schema, "variant_contig") == { "name": "variant_contig", @@ -505,18 +482,6 @@ class TestVcfDescriptions: def test_fields(self, schema, field, description): assert schema.field_map()[field].description == description - @pytest.mark.parametrize( - ("filt", "description"), - [ - ("PASS", "All filters passed"), - ("s50", "Less than 50% of samples have data"), - ("q10", "Quality below 10"), - ], - ) - def test_filters(self, schema, filt, description): - d = {f.id: f.description for f in 
schema.filters} - assert d[filt] == description - class TestVcfZarrWriterExample: arrays = ( @@ -689,33 +654,6 @@ def test_call_fields(self, tmp_path, field): icf_mod.explode(tmp_path / "x.icf", [tmp_path / "test.vcf.gz"]) -class TestBadSchemaChanges: - # [{'id': 'NA00001'}, {'id': 'NA00002'}, {'id': 'NA00003'}], - @pytest.mark.parametrize( - "samples", - [ - [], - [{"id": "NA00001"}, {"id": "NA00003"}], - [{"id": "NA00001"}, {"id": "NA00002"}, {"id": "NA00004"}], - [ - {"id": "NA00001"}, - {"id": "NA00002"}, - {"id": "NA00003"}, - {"id": "NA00004"}, - ], - [{"id": "NA00001"}, {"id": "NA00003"}, {"id": "NA00002"}], - ], - ) - def test_removed_samples(self, tmp_path, schema, icf_path, samples): - d = schema.asdict() - d["samples"] = samples - schema_path = tmp_path / "schema.json" - with open(schema_path, "w") as f: - json.dump(d, f) - with pytest.raises(ValueError, match="Subsetting or reordering samples"): - icf_mod.encode(icf_path, tmp_path / "z", schema_path=schema_path) - - class TestInspect: def test_icf(self, icf_path): df = pd.DataFrame(icf_mod.inspect(icf_path))