Skip to content

Commit bf958f6

Browse files
Add prefix argument variant_id to plink conversion
1 parent 928965e commit bf958f6

File tree

2 files changed

+77
-11
lines changed

2 files changed

+77
-11
lines changed

bio2zarr/plink.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dataclasses
12
import logging
23
import pathlib
34

@@ -9,17 +10,39 @@
910
logger = logging.getLogger(__name__)
1011

1112

13+
@dataclasses.dataclass
14+
class PlinkPaths:
15+
bed_path: pathlib.Path
16+
bim_path: pathlib.Path
17+
fam_path: pathlib.Path
18+
19+
1220
class PlinkFormat(vcz.Source):
1321
@core.requires_optional_dependency("bed_reader", "plink")
14-
def __init__(self, path):
22+
def __init__(self, prefix):
1523
import bed_reader
1624

17-
self._path = pathlib.Path(path)
18-
self.bed = bed_reader.open_bed(path, num_threads=1, count_A1=False)
25+
# TODO we will need support multiple chromosomes here to join
26+
# plinks into on big zarr. So, these will require multiple
27+
# bed and bim files, but should share a .fam
28+
self.prefix = pathlib.Path(prefix)
29+
paths = PlinkPaths(
30+
self.prefix.with_suffix(".bed"),
31+
self.prefix.with_suffix(".bim"),
32+
self.prefix.with_suffix(".fam"),
33+
)
34+
35+
self.bed = bed_reader.open_bed(
36+
paths.bed_path,
37+
bim_location=paths.bim_path,
38+
fam_location=paths.fam_path,
39+
num_threads=1,
40+
count_A1=False,
41+
)
1942

2043
@property
2144
def path(self):
22-
return self._path
45+
return self.prefix
2346

2447
@property
2548
def num_records(self):
@@ -46,6 +69,9 @@ def iter_field(self, field_name, shape, start, stop):
4669
assert field_name == "position" # Only position field is supported from plink
4770
yield from self.bed.bp_position[start:stop]
4871

72+
def iter_id(self, start, stop):
73+
yield from self.bed.sid[start:stop]
74+
4975
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
5076
ref_field = self.bed.allele_1
5177
alt_field = self.bed.allele_2
@@ -107,6 +133,18 @@ def generate_schema(
107133
dimensions=["variants", "alleles"],
108134
description=None,
109135
),
136+
vcz.ZarrArraySpec(
137+
name="variant_id",
138+
dtype="O",
139+
dimensions=["variants"],
140+
description=None,
141+
),
142+
vcz.ZarrArraySpec(
143+
name="variant_id_mask",
144+
dtype="bool",
145+
dimensions=["variants"],
146+
description=None,
147+
),
110148
vcz.ZarrArraySpec(
111149
source=None,
112150
name="variant_length",
@@ -147,20 +185,20 @@ def generate_schema(
147185

148186

149187
def convert(
150-
bed_path,
151-
zarr_path,
188+
prefix,
189+
out,
152190
*,
153191
variants_chunk_size=None,
154192
samples_chunk_size=None,
155193
worker_processes=1,
156194
show_progress=False,
157195
):
158-
plink_format = PlinkFormat(bed_path)
196+
plink_format = PlinkFormat(prefix)
159197
schema_instance = plink_format.generate_schema(
160198
variants_chunk_size=variants_chunk_size,
161199
samples_chunk_size=samples_chunk_size,
162200
)
163-
zarr_path = pathlib.Path(zarr_path)
201+
zarr_path = pathlib.Path(out)
164202
vzw = vcz.VcfZarrWriter(PlinkFormat, zarr_path)
165203
# Rough heuristic to split work up enough to keep utilisation high
166204
target_num_partitions = max(1, worker_processes * 4)

tests/test_plink.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,43 @@ def test_simulated_example(self, tmp_path):
7979
bed_path = data_path + "plink_sim_10s_100v_10pmiss.bed"
8080
fam_path = data_path + "plink_sim_10s_100v_10pmiss.fam"
8181
bim_path = data_path + "plink_sim_10s_100v_10pmiss.bim"
82-
# print(bed_path)
83-
# print(fam_path)
8482
sg_ds = sgkit.io.plink.read_plink(
8583
bed_path=bed_path, fam_path=fam_path, bim_path=bim_path
8684
)
8785
out = tmp_path / "example.plink.zarr"
88-
plink.convert(bed_path, out)
86+
plink.convert(prefix=data_path + "/plink_sim_10s_100v_10pmiss", out=out)
8987
ds = sg.load_dataset(out)
9088
nt.assert_array_equal(ds.call_genotype.values, sg_ds.call_genotype.values)
89+
nt.assert_array_equal(
90+
ds.call_genotype_mask.values, sg_ds.call_genotype_mask.values
91+
)
92+
# sgkit doesn't have phased
93+
nt.assert_array_equal(ds.variant_position.values, sg_ds.variant_position.values)
94+
nt.assert_array_equal(
95+
ds.variant_allele.values, sg_ds.variant_allele.values.astype("U")
96+
)
97+
nt.assert_array_equal(ds.variant_contig.values, sg_ds.variant_contig.values)
98+
nt.assert_array_equal(ds.variant_id.values, sg_ds.variant_id.values)
99+
# print(sg_ds.variant_id.values)
100+
101+
# Can't compare to sgkit because of
102+
# https://github.com/sgkit-dev/sgkit/issues/1314
103+
nt.assert_array_equal(
104+
ds.sample_id.values,
105+
[
106+
"000",
107+
"001",
108+
"002",
109+
"003",
110+
"004",
111+
"005",
112+
"006",
113+
"007",
114+
"008",
115+
"009",
116+
],
117+
)
118+
# We don't do the additional sample_ fields yet
91119

92120

93121
class TestExample:

0 commit comments

Comments
 (0)