Skip to content

Commit ead8f3c

Browse files
tomwhite authored and jeromekelleher committed
Add samples module
1 parent 549cdb3 commit ead8f3c

File tree

2 files changed

+55
-37
lines changed

2 files changed

+55
-37
lines changed

vcztools/samples.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import logging
2+
from typing import Optional
3+
4+
import numpy as np
5+
6+
from vcztools.utils import search
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
def parse_samples(
    samples: Optional[str], all_samples: np.ndarray, *, force_samples: bool = True
) -> tuple[np.ndarray, Optional[np.ndarray]]:
    """Parse a bcftools-style samples string.

    The string is a comma-separated list of sample IDs. A single leading
    ``^`` means the listed samples are excluded rather than included.

    Args:
        samples: The samples string, or ``None`` to select every sample.
        all_samples: Array of all sample IDs available in the dataset.
        force_samples: If true, sample IDs that are not present in
            ``all_samples`` are logged as a warning and dropped; otherwise
            they cause a ``ValueError``.

    Returns:
        A tuple ``(sample_ids, samples_selection)``. ``sample_ids`` is an
        array of the selected sample IDs and ``samples_selection`` is the
        array indicating the selection from ``all_samples`` (as produced by
        ``vcztools.utils.search``; presumably index positions -- confirm
        against that helper). When ``samples`` is ``None`` the result is
        ``(all_samples, None)``.

    Raises:
        ValueError: If unknown samples are requested and ``force_samples``
            is false.
    """
    if samples is None:
        return all_samples, None

    exclude_samples = samples.startswith("^")
    # Only the first "^" is the exclusion marker. str.lstrip("^") would strip
    # every leading caret and silently change which sample IDs are requested.
    samples = samples.removeprefix("^")
    sample_ids = np.array(samples.split(","))
    if np.all(sample_ids == np.array("")):
        # An empty list (e.g. "" or "^") selects no named samples.
        sample_ids = np.empty((0,))

    unknown_samples = np.setdiff1d(sample_ids, all_samples)
    if len(unknown_samples) > 0:
        unknown_list = ",".join(unknown_samples)
        if not force_samples:
            raise ValueError(
                "subset called for sample(s) not in header: "
                f"{unknown_list}. "
                'Use "--force-samples" to ignore this error.'
            )
        logger.warning(
            "subset called for sample(s) not in header: %s.", unknown_list
        )
        # remove unknown samples from sample_ids
        sample_ids = np.delete(sample_ids, search(sample_ids, unknown_samples))

    samples_selection = search(all_samples, sample_ids)
    if exclude_samples:
        # Invert the selection: keep every position that was not listed.
        samples_selection = np.setdiff1d(np.arange(all_samples.size), samples_selection)
        sample_ids = all_samples[samples_selection]
    return sample_ids, samples_selection

vcztools/vcf_writer.py

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
regions_to_chunk_indexes,
1515
regions_to_selection,
1616
)
17+
from vcztools.samples import parse_samples
1718
from vcztools.utils import (
1819
open_file_like,
19-
search,
2020
)
2121

2222
from . import _vcztools, constants, retrieval
@@ -134,48 +134,16 @@ def write_vcf(
134134
root = zarr.open(vcz, mode="r")
135135

136136
with open_file_like(output) as output:
137-
force_ac_an_header = False
138137
if samples and drop_genotypes:
139138
raise ValueError("Cannot select samples and drop genotypes.")
140139
elif drop_genotypes:
141140
sample_ids = []
142141
samples_selection = np.array([])
143-
elif samples is None:
144-
sample_ids = root["sample_id"][:]
145-
samples_selection = None
146142
else:
147-
force_ac_an_header = True
148143
all_samples = root["sample_id"][:]
149-
exclude_samples = samples.startswith("^")
150-
samples = samples.lstrip("^")
151-
sample_ids = np.array(samples.split(","))
152-
if np.all(sample_ids == np.array("")):
153-
sample_ids = np.empty((0,))
154-
155-
unknown_samples = np.setdiff1d(sample_ids, all_samples)
156-
if len(unknown_samples) > 0:
157-
if force_samples:
158-
# remove unknown samples from sample_ids
159-
logger.warning(
160-
"subset called for sample(s) not in header: "
161-
f'{",".join(unknown_samples)}.'
162-
)
163-
sample_ids = np.delete(
164-
sample_ids, search(sample_ids, unknown_samples)
165-
)
166-
else:
167-
raise ValueError(
168-
"subset called for sample(s) not in header: "
169-
f'{",".join(unknown_samples)}. '
170-
'Use "--force-samples" to ignore this error.'
171-
)
172-
173-
samples_selection = search(all_samples, sample_ids)
174-
if exclude_samples:
175-
samples_selection = np.setdiff1d(
176-
np.arange(all_samples.size), samples_selection
177-
)
178-
sample_ids = all_samples[samples_selection]
144+
sample_ids, samples_selection = parse_samples(
145+
samples, all_samples, force_samples=force_samples
146+
)
179147

180148
filter_expr = filter_mod.FilterExpression(
181149
field_names=set(root), include=include, exclude=exclude
@@ -184,6 +152,7 @@ def write_vcf(
184152

185153
if not no_header:
186154
original_header = root.attrs.get("vcf_header", None)
155+
force_ac_an_header = not drop_genotypes and samples_selection is not None
187156
vcf_header = _generate_header(
188157
root,
189158
original_header,
@@ -336,7 +305,7 @@ def c_chunk_to_vcf(
336305
if (
337306
"call_genotype_phased" in root
338307
and not drop_genotypes
339-
and (samples_selection is None or num_samples > 0)
308+
and (samples_selection is None or num_samples != 0)
340309
):
341310
gt_phased = get_vchunk_array(
342311
root["call_genotype_phased"],

0 commit comments

Comments
 (0)