diff --git a/CHANGELOG.md b/CHANGELOG.md index 11c3c261..8241ad4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ - Add contigs to plink output (#344) +- Add variant_length and indexing to plink output (#382) + Breaking changes - Remove explicit sample, contig and filter lists from the schema. diff --git a/bio2zarr/plink.py b/bio2zarr/plink.py index 76cad279..112f77aa 100644 --- a/bio2zarr/plink.py +++ b/bio2zarr/plink.py @@ -63,7 +63,7 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles): gt[bed_chunk[i] == 2] = 1 gt[bed_chunk[i] == 1, 0] = 1 - yield alleles, (gt, phased) + yield vcz.VariantData(max(len(a) for a in alleles), alleles, gt, phased) def generate_schema( self, @@ -112,6 +112,13 @@ def generate_schema( dimensions=["variants", "alleles"], description=None, ), + vcz.ZarrArraySpec( + source=None, + name="variant_length", + dtype="i4", + dimensions=["variants"], + description="Length of each variant", + ), vcz.ZarrArraySpec( name="variant_contig", dtype=core.min_int_dtype(0, len(np.unique(self.bed.chromosome))), diff --git a/bio2zarr/tskit.py b/bio2zarr/tskit.py index 112e901a..a5241389 100644 --- a/bio2zarr/tskit.py +++ b/bio2zarr/tskit.py @@ -97,21 +97,23 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles): left=self.positions[start], right=self.positions[stop] if stop < self.num_records else None, samples=self.tskit_samples, + copy=False, ): gt = np.full(shape, constants.INT_FILL, dtype=np.int8) alleles = np.full(num_alleles, constants.STR_FILL, dtype="O") + variant_length = 0 for i, allele in enumerate(variant.alleles): # None is returned by tskit in the case of a missing allele if allele is None: continue assert i < num_alleles alleles[i] = allele - + variant_length = max(variant_length, len(allele)) gt[self.sample_indices, self.ploidy_indices] = variant.genotypes[ self.genotype_indices ] - yield alleles, (gt, phased) + yield vcz.VariantData(variant_length, alleles, gt, phased) def generate_schema( self, @@ -164,6 +166,16 @@ def generate_schema( min_position = np.min(self.ts.sites_position) max_position = np.max(self.ts.sites_position) + tables = self.ts.tables + ancestral_state_offsets = tables.sites.ancestral_state_offset + derived_state_offsets = tables.mutations.derived_state_offset + ancestral_lengths = ancestral_state_offsets[1:] - ancestral_state_offsets[:-1] + derived_lengths = derived_state_offsets[1:] - derived_state_offsets[:-1] + max_variant_length = max( + np.max(ancestral_lengths) if len(ancestral_lengths) > 0 else 0, + np.max(derived_lengths) if len(derived_lengths) > 0 else 0, + ) + array_specs = [ vcz.ZarrArraySpec( source="position", @@ -179,6 +191,13 @@ def generate_schema( dimensions=["variants", "alleles"], description="Alleles for each variant", ), + vcz.ZarrArraySpec( + source=None, + name="variant_length", + dtype=core.min_int_dtype(0, max_variant_length), + dimensions=["variants"], + description="Length of each variant", + ), vcz.ZarrArraySpec( source=None, name="variant_contig", diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py index e5f26d95..252dbd27 100644 --- a/bio2zarr/vcf.py +++ b/bio2zarr/vcf.py @@ -1040,14 +1040,19 @@ def iter_genotypes(self, shape, start, stop): yield sanitised_genotypes, sanitised_phased def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles): + variant_lengths = self.fields["rlen"].iter_values(start, stop) if self.gt_field is None or shape is None: - for alleles in self.iter_alleles(start, stop, num_alleles): - yield alleles, (None, None) + for variant_length, alleles in zip( + variant_lengths, self.iter_alleles(start, stop, num_alleles) + ): + yield vcz.VariantData(variant_length, alleles, None, None) else: - yield from zip( + for variant_length, alleles, (gt, phased) in zip( + variant_lengths, self.iter_alleles(start, stop, num_alleles), self.iter_genotypes(shape, start, stop), - ) + ): + yield vcz.VariantData(variant_length, alleles, gt, phased) def generate_schema( self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None @@ -1121,6 +1126,7 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)): compressor=compressor, ) + name_map = {field.full_name: field for field in self.metadata.fields} array_specs = [ fixed_field_spec( name="variant_contig", @@ -1136,6 +1142,11 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)): dtype="O", dimensions=["variants", "alleles"], ), + fixed_field_spec( + name="variant_length", + dtype=name_map["rlen"].smallest_dtype(), + dimensions=["variants"], + ), fixed_field_spec( name="variant_id", dtype="O", @@ -1145,14 +1156,12 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)): dtype="bool", ), ] - name_map = {field.full_name: field for field in self.metadata.fields} - # Only three of the fixed fields have a direct one-to-one mapping. + # Only two of the fixed fields have a direct one-to-one mapping. array_specs.extend( [ spec_from_field(name_map["QUAL"], array_name="variant_quality"), spec_from_field(name_map["POS"], array_name="variant_position"), - spec_from_field(name_map["rlen"], array_name="variant_length"), ] ) array_specs.extend( diff --git a/bio2zarr/vcz.py b/bio2zarr/vcz.py index 452ef5e2..634947ca 100644 --- a/bio2zarr/vcz.py +++ b/bio2zarr/vcz.py @@ -37,6 +37,16 @@ } +@dataclasses.dataclass +class VariantData: + """Represents variant data returned by iter_alleles_and_genotypes.""" + + variant_length: int + alleles: np.ndarray + genotypes: np.ndarray + phased: np.ndarray + + class Source(abc.ABC): @property @abc.abstractmethod @@ -794,6 +804,7 @@ def encode_array_partition(self, array_spec, partition_index): def encode_alleles_and_genotypes_partition(self, partition_index): partition = self.metadata.partitions[partition_index] alleles = self.init_partition_array(partition_index, "variant_allele") + variant_lengths = self.init_partition_array(partition_index, "variant_length") has_gt = self.has_genotypes() shape = None if has_gt: @@ -802,18 +813,21 @@ def encode_alleles_and_genotypes_partition(self, partition_index): partition_index, "call_genotype_phased" ) shape = gt.buff.shape[1:] - for alleles_value, (genotype, phased) in self.source.iter_alleles_and_genotypes( + for variant_data in self.source.iter_alleles_and_genotypes( partition.start, partition.stop, shape, alleles.array.shape[1] ): j_alleles = alleles.next_buffer_row() - alleles.buff[j_alleles] = alleles_value + alleles.buff[j_alleles] = variant_data.alleles + j_variant_length = variant_lengths.next_buffer_row() + variant_lengths.buff[j_variant_length] = variant_data.variant_length if has_gt: j = gt.next_buffer_row() - gt.buff[j] = genotype + gt.buff[j] = variant_data.genotypes j_phased = gt_phased.next_buffer_row() - gt_phased.buff[j_phased] = phased + gt_phased.buff[j_phased] = variant_data.phased self.finalise_partition_array(partition_index, alleles) + self.finalise_partition_array(partition_index, variant_lengths) if has_gt: self.finalise_partition_array(partition_index, gt) self.finalise_partition_array(partition_index, gt_phased) @@ -1103,14 +1117,16 @@ def __init__(self, path): def create_index(self): """Create an index to support efficient region queries.""" root = zarr.open_group(store=self.path, mode="r+") - + print(list(root.keys())) if ( "variant_contig" not in root or "variant_position" not in root or "variant_length" not in root ): - logger.warning("Cannot create index: required arrays not found") - return + raise ValueError( + "Cannot create index: variant_contig, " + "variant_position and variant_length arrays are required" + ) contig = root["variant_contig"] pos = root["variant_position"] diff --git a/tests/data/plink/example.bim b/tests/data/plink/example.bim index a20a2bc3..4a7bb349 100644 --- a/tests/data/plink/example.bim +++ b/tests/data/plink/example.bim @@ -1,2 +1,2 @@ -1 1_10 0 10 A G -1 1_20 0 20 T C +1 1_10 0 10 A GG +1 1_20 0 20 TTT C diff --git a/tests/test_core.py b/tests/test_core.py index de596d29..55abc908 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -237,7 +237,7 @@ def test_examples(self, chunk_size, size, start, stop): # It works in CI on Linux, but it'll probably break at some point. # It's also necessary to update these numbers each time a new data # file gets added - ("tests/data", 5045029), + ("tests/data", 5045032), ("tests/data/vcf", 5018640), ("tests/data/vcf/sample.vcf.gz", 1089), ], diff --git a/tests/test_plink.py b/tests/test_plink.py index 21ae557a..6377776e 100644 --- a/tests/test_plink.py +++ b/tests/test_plink.py @@ -94,8 +94,8 @@ class TestExample: """ .bim file looks like this: - 1 1_10 0 10 A G - 1 1_20 0 20 T C + 1 1_10 0 10 A GG + 1 1_20 0 20 TTT C Definition: https://www.cog-genomics.org/plink/1.9/formats#bim Chromosome code (either an integer, or 'X'/'Y'/'XY'/'MT'; '0' @@ -121,7 +121,10 @@ def test_variant_position(self, ds): nt.assert_array_equal(ds.variant_position, [10, 20]) def test_variant_allele(self, ds): - nt.assert_array_equal(ds.variant_allele, [["A", "G"], ["T", "C"]]) + nt.assert_array_equal(ds.variant_allele, [["A", "GG"], ["TTT", "C"]]) + + def test_variant_length(self, ds): + nt.assert_array_equal(ds.variant_length, [2, 3]) def test_contig_id(self, ds): """Test that contig identifiers are correctly extracted and stored.""" @@ -266,6 +269,9 @@ def test_chunk_size( worker_processes=worker_processes, ) ds2 = sg.load_dataset(out) + # Drop the region_index as it is chunk dependent + ds = ds.drop_vars("region_index") + ds2 = ds2.drop_vars("region_index") xt.assert_equal(ds, ds2) # TODO check array chunks @@ -372,3 +378,9 @@ def test_genotypes(self, ds): def test_variant_position(self, ds): nt.assert_array_equal(ds.variant_position, [10, 20, 10, 10, 20, 10]) + + def test_variant_length(self, ds): + nt.assert_array_equal( + ds.variant_length, + [1, 1, 1, 1, 1, 1], + ) diff --git a/tests/test_ts.py b/tests/test_ts.py index 80fbe8d5..e56924c9 100644 --- a/tests/test_ts.py +++ b/tests/test_ts.py @@ -24,11 +24,11 @@ def test_simple_tree_sequence(self, tmp_path): tables.edges.add_row(left=0, right=100, parent=5, child=2) tables.edges.add_row(left=0, right=100, parent=5, child=3) site_id = tables.sites.add_row(position=10, ancestral_state="A") - tables.mutations.add_row(site=site_id, node=4, derived_state="T") - site_id = tables.sites.add_row(position=20, ancestral_state="C") + tables.mutations.add_row(site=site_id, node=4, derived_state="TTTT") + site_id = tables.sites.add_row(position=20, ancestral_state="CCC") tables.mutations.add_row(site=site_id, node=5, derived_state="G") site_id = tables.sites.add_row(position=30, ancestral_state="G") - tables.mutations.add_row(site=site_id, node=0, derived_state="A") + tables.mutations.add_row(site=site_id, node=0, derived_state="AA") tables.sort() tree_sequence = tables.tree_sequence() tree_sequence.dump(tmp_path / "test.trees") @@ -54,7 +54,12 @@ def test_simple_tree_sequence(self, tmp_path): alleles = zroot["variant_allele"][:] assert alleles.shape == (3, 2) assert alleles.dtype == "O" - assert np.array_equal(alleles, [["A", "T"], ["C", "G"], ["G", "A"]]) + assert np.array_equal(alleles, [["A", "TTTT"], ["CCC", "G"], ["G", "AA"]]) + + lengths = zroot["variant_length"][:] + assert lengths.shape == (3,) + assert lengths.dtype == np.int8 + assert np.array_equal(lengths, [4, 3, 2]) genotypes = zroot["call_genotype"][:] assert genotypes.shape == (3, 2, 2) @@ -65,7 +70,7 @@ def test_simple_tree_sequence(self, tmp_path): phased = zroot["call_genotype_phased"][:] assert phased.shape == (3, 2) - assert phased.dtype == np.bool + assert phased.dtype == "bool" assert np.all(phased) contigs = zroot["contig_id"][:] @@ -83,15 +88,22 @@ def test_simple_tree_sequence(self, tmp_path): assert samples.dtype == "O" assert np.array_equal(samples, ["tsk_0", "tsk_1"]) + region_index = zroot["region_index"][:] + assert region_index.shape == (1, 6) + assert region_index.dtype == np.int8 + assert np.array_equal(region_index, [[0, 0, 10, 30, 31, 3]]) + assert set(zroot.array_keys()) == { "variant_position", "variant_allele", + "variant_length", "call_genotype", "call_genotype_phased", "call_genotype_mask", "contig_id", "variant_contig", "sample_id", + "region_index", } def test_missing_dependency(self): @@ -129,8 +141,8 @@ def simple_ts(self, tmp_path): tables.edges.add_row(left=0, right=100, parent=5, child=2) tables.edges.add_row(left=0, right=100, parent=5, child=3) site_id = tables.sites.add_row(position=10, ancestral_state="A") - tables.mutations.add_row(site=site_id, node=4, derived_state="T") - site_id = tables.sites.add_row(position=20, ancestral_state="C") + tables.mutations.add_row(site=site_id, node=4, derived_state="TT") + site_id = tables.sites.add_row(position=20, ancestral_state="CCC") tables.mutations.add_row(site=site_id, node=5, derived_state="G") site_id = tables.sites.add_row(position=30, ancestral_state="G") tables.mutations.add_row(site=site_id, node=0, derived_state="A") @@ -264,6 +276,7 @@ def test_schema_generation(self, simple_ts): field_names = [field.name for field in schema.fields] assert "variant_position" in field_names assert "variant_allele" in field_names + assert "variant_length" in field_names assert "variant_contig" in field_names assert "call_genotype" in field_names assert "call_genotype_phased" in field_names @@ -335,18 +348,22 @@ def test_iter_alleles_and_genotypes(self, simple_ts, ind_nodes, expected_gts): assert len(results) == 3 - for i, (alleles, (gt, phased)) in enumerate(results): + for i, variant_data in enumerate(results): if i == 0: - assert tuple(alleles) == ("A", "T") + assert variant_data.variant_length == 2 + assert np.array_equal(variant_data.alleles, ("A", "TT")) elif i == 1: - assert tuple(alleles) == ("C", "G") + assert variant_data.variant_length == 3 + assert np.array_equal(variant_data.alleles, ("CCC", "G")) elif i == 2: - assert tuple(alleles) == ("G", "A") + assert variant_data.variant_length == 1 + assert np.array_equal(variant_data.alleles, ("G", "A")) assert np.array_equal( - gt, expected_gts[i] - ), f"Mismatch at variant {i}, expected {expected_gts[i]}, got {gt}" - assert np.all(phased) + variant_data.genotypes, expected_gts[i] + ), f"Mismatch at variant {i}, expected {expected_gts[i]}, " + f"got {variant_data.genotypes}" + assert np.all(variant_data.phased) def test_iter_alleles_and_genotypes_errors(self, simple_ts): """Test error cases for iter_alleles_and_genotypes with invalid inputs.""" @@ -414,12 +431,12 @@ def insert_branch_sites(ts, m=1): ) assert len(results_default) == 1 - alleles, (gt_default, phased) = results_default[0] - assert tuple(alleles) == ("0", "1") + variant_data_default = results_default[0] + assert np.array_equal(variant_data_default.alleles, ("0", "1")) # Sample 2 should have the ancestral state (0) when isolated_as_missing=False expected_gt_default = np.array([[1], [0], [0]]) - assert np.array_equal(gt_default, expected_gt_default) + assert np.array_equal(variant_data_default.genotypes, expected_gt_default) format_obj_missing = ts.TskitFormat( ts_path, individuals_nodes=ind_nodes, isolated_as_missing=True @@ -429,12 +446,13 @@ def insert_branch_sites(ts, m=1): ) assert len(results_missing) == 1 - alleles, (gt_missing, phased) = results_missing[0] - assert tuple(alleles) == ("0", "1") + variant_data_missing = results_missing[0] + assert variant_data_missing.variant_length == 1 + assert np.array_equal(variant_data_missing.alleles, ("0", "1")) # Individual 2 should have missing values (-1) when isolated_as_missing=True expected_gt_missing = np.array([[1], [0], [-1]]) - assert np.array_equal(gt_missing, expected_gt_missing) + assert np.array_equal(variant_data_missing.genotypes, expected_gt_missing) def test_genotype_dtype_selection(self, tmp_path): tables = tskit.TableCollection(sequence_length=100) diff --git a/tests/test_vcz.py b/tests/test_vcz.py index fd33b1a8..7a5d8c31 100644 --- a/tests/test_vcz.py +++ b/tests/test_vcz.py @@ -1,5 +1,6 @@ import json +import numpy as np import numpy.testing as nt import pandas as pd import pysam @@ -896,3 +897,14 @@ def test_max_number_exceeds_dimension_size( ValueError, match=f"Max number of values {max_number} exceeds max" ): vcz.ZarrArraySpec.from_field(vcf_field, schema) + + +def test_create_index_errors(tmp_path): + root = zarr.open(tmp_path) + root["foobar"] = np.array([1, 2, 3]) + vzw = vcz.VcfZarrWriter( + None, + tmp_path, + ) + with pytest.raises(ValueError, match="Cannot create index"): + vzw.create_index()