1
+ import logging
2
+
3
+ import humanfriendly
1
4
import numpy as np
2
5
import zarr
3
6
import bed_reader
4
7
5
8
from . import core
6
9
7
10
8
- def encode_bed_partition_genotypes (
9
- bed_path , zarr_path , start_variant , end_variant , encoder_threads = 8
10
- ):
11
- bed = bed_reader .open_bed (bed_path , num_threads = 1 )
11
+ logger = logging .getLogger (__name__ )
12
+
12
13
14
def encode_genotypes_slice(bed_path, zarr_path, start, stop):
    """Encode bed variants ``start:stop`` into the Zarr genotype arrays.

    The bed file is read one Zarr chunk of variants at a time and each
    variant row is written to the ``call_genotype``, ``call_genotype_mask``
    and ``call_genotype_phased`` arrays via ``core.BufferedArray``.

    Bed values are translated to genotype calls as follows:

        -127 -> [-1, -1]  (and the mask is set for both alleles)
           0 -> [0, 0]
           1 -> [1, 0]
           2 -> [1, 1]

    ``call_genotype_phased`` is always written as ``False``.

    Args:
        bed_path: Path of the bed file, opened with ``bed_reader.open_bed``.
        zarr_path: Path of the Zarr directory store to write into.
        start: Index of the first variant to encode; must be a multiple of
            the first-dimension chunk size of ``call_genotype``.
        stop: Index one past the last variant to encode.

    Raises:
        AssertionError: If ``start`` is not aligned to a chunk boundary.
    """
    bed = bed_reader.open_bed(bed_path, num_threads=1)
    store = zarr.DirectoryStore(zarr_path)
    root = zarr.group(store=store)
    gt = core.BufferedArray(root["call_genotype"], start)
    gt_mask = core.BufferedArray(root["call_genotype_mask"], start)
    gt_phased = core.BufferedArray(root["call_genotype_phased"], start)
    chunk_length = gt.array.chunks[0]
    assert start % chunk_length == 0

    logger.debug(f"Reading slice {start}:{stop}")
    chunk_start = start
    while chunk_start < stop:
        chunk_stop = min(chunk_start + chunk_length, stop)
        logger.debug(f"Reading bed slice {chunk_start}:{chunk_stop}")
        # Transposed so that iterating yields one variant (row) at a time.
        bed_chunk = bed.read(slice(chunk_start, chunk_stop), dtype=np.int8).T
        logger.debug(f"Got bed slice {humanfriendly.format_size(bed_chunk.nbytes)}")
        # Probably should do this without iterating over rows, but it's a bit
        # simpler and lines up better with the array buffering API. The
        # bottleneck is in the encoding anyway.
        for values in bed_chunk:
            gt_row = gt.next_buffer_row()
            g = np.zeros_like(gt.buff[gt_row])
            g[values == -127] = -1
            g[values == 2] = 1
            g[values == 1, 0] = 1
            gt.buff[gt_row] = g

            phased_row = gt_phased.next_buffer_row()
            gt_phased.buff[phased_row] = False

            # Derive the mask from the row just built rather than reading it
            # back out of gt's buffer with gt_mask's row index, which is only
            # valid while the buffers happen to advance in lockstep.
            mask_row = gt_mask.next_buffer_row()
            gt_mask.buff[mask_row] = g == -1
        chunk_start = chunk_stop
    # Write out whatever is still held in the buffers.
    gt.flush()
    gt_phased.flush()
    gt_mask.flush()
    logger.debug(f"GT slice {start}:{stop} done")
41
51
42
52
43
53
def convert (
@@ -53,6 +63,7 @@ def convert(
53
63
n = bed .iid_count
54
64
m = bed .sid_count
55
65
del bed
66
+ logging .info (f"Scanned plink with { n } samples and { m } variants" )
56
67
57
68
# FIXME
58
69
if chunk_width is None :
@@ -81,7 +92,7 @@ def convert(
81
92
dimensions += ["ploidy" ]
82
93
a = root .empty (
83
94
"call_genotype" ,
84
- dtype = "i8 " ,
95
+ dtype = "i1 " ,
85
96
shape = list (shape ),
86
97
chunks = list (chunks ),
87
98
compressor = core .default_compressor ,
@@ -97,22 +108,52 @@ def convert(
97
108
)
98
109
a .attrs ["_ARRAY_DIMENSIONS" ] = list (dimensions )
99
110
100
- chunks_per_future = 2 # FIXME - make a parameter
101
- start = 0
102
- partitions = []
103
- while start < m :
104
- stop = min (m , start + chunk_length * chunks_per_future )
105
- partitions .append ((start , stop ))
106
- start = stop
107
- assert start == m
111
+ num_slices = max (1 , worker_processes * 4 )
112
+ slices = core .chunk_aligned_slices (a , num_slices )
113
+
114
+ total_chunks = sum (a .nchunks for a in root .values ())
108
115
109
116
progress_config = core .ProgressConfig (
110
- total = m , title = "Convert" , units = "vars " , show = show_progress
117
+ total = total_chunks , title = "Convert" , units = "chunks " , show = show_progress
111
118
)
112
119
with core .ParallelWorkManager (worker_processes , progress_config ) as pwm :
113
- for start , end in partitions :
114
- pwm .submit (encode_bed_partition_genotypes , bed_path , zarr_path , start , end )
120
+ for start , stop in slices :
121
+ pwm .submit (encode_genotypes_slice , bed_path , zarr_path , start , stop )
115
122
116
123
# TODO also add atomic swap like VCF. Should be abstracted to
117
124
# share basic code for setting up the variation dataset zarr
118
125
zarr .consolidate_metadata (zarr_path )
126
+
127
+
128
# FIXME do this more efficiently - currently reading the whole thing
# in for convenience
def validate(bed_path, zarr_path):
    """Check that the Zarr ``call_genotype`` array matches the bed file.

    Both datasets are read fully into memory and every call is compared
    against the expected encoding (bed value -> zarr call):

        -127 -> [-1, -1]
           0 -> [0, 0]
           1 -> [1, 0]
           2 -> [1, 1]

    Args:
        bed_path: Path of the bed file, opened with ``bed_reader.open_bed``.
        zarr_path: Path of the Zarr directory store holding ``call_genotype``.

    Raises:
        AssertionError: If the array shapes disagree, the bed data contains
            a value other than those listed above, or any call differs from
            its expected encoding.
    """
    store = zarr.DirectoryStore(zarr_path)
    root = zarr.group(store=store)
    call_genotype = root["call_genotype"][:]

    bed = bed_reader.open_bed(bed_path, num_threads=1)

    # Explicit raises (rather than assert statements) so the checks still
    # run when Python is started with -O.
    expected_shape = (bed.sid_count, bed.iid_count, 2)
    if call_genotype.shape != expected_shape:
        raise AssertionError(
            f"call_genotype has shape {call_genotype.shape}, "
            f"expected {expected_shape}"
        )
    # Transposed to (variants, samples) to line up with call_genotype.
    bed_genotypes = bed.read(dtype="int8").T
    if bed_genotypes.shape != expected_shape[:2]:
        raise AssertionError(
            f"bed genotypes have shape {bed_genotypes.shape}, "
            f"expected {expected_shape[:2]}"
        )

    unknown = ~np.isin(bed_genotypes, [-127, 0, 1, 2])
    if unknown.any():
        raise AssertionError(
            f"Unexpected bed genotype value {bed_genotypes[unknown][0]}"
        )

    # Build the encoding we expect for every call in one pass.
    expected = np.zeros(expected_shape, dtype=np.int8)
    expected[bed_genotypes == -127] = -1
    expected[bed_genotypes == 2] = 1
    expected[bed_genotypes == 1, 0] = 1

    mismatch = (call_genotype != expected).any(axis=2)
    if mismatch.any():
        variant, sample = np.argwhere(mismatch)[0]
        raise AssertionError(
            f"Genotype mismatch at variant {variant}, sample {sample}: "
            f"bed value {bed_genotypes[variant, sample]} should encode as "
            f"{expected[variant, sample].tolist()}, "
            f"found {call_genotype[variant, sample].tolist()}"
        )
0 commit comments