Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions vcztools/plink.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import pandas as pd
import zarr

from vcztools.utils import _as_fixed_length_unicode

from . import _vcztools, retrieval


Expand All @@ -24,7 +26,7 @@ def encode_genotypes(genotypes, a12_allele=None):

def generate_fam(root):
# TODO generate an error if sample_id contains a space
sample_id = root["sample_id"][:].astype(str)
sample_id = _as_fixed_length_unicode(root["sample_id"][:])
zeros = np.zeros(sample_id.shape, dtype=int)
df = pd.DataFrame(
{
Expand All @@ -41,8 +43,8 @@ def generate_fam(root):

def generate_bim(root, a12_allele):
select = a12_allele[:, 1] != -1
contig_id = root["contig_id"][:].astype(str)
alleles = root["variant_allele"][:].astype(str)[select]
contig_id = _as_fixed_length_unicode(root["contig_id"][:])
alleles = _as_fixed_length_unicode(root["variant_allele"][:])[select]
a12_allele = a12_allele[select]
num_variants = np.sum(select)
allele_1 = alleles[np.arange(num_variants), a12_allele[:, 0]]
Expand Down
3 changes: 2 additions & 1 deletion vcztools/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
regions_to_selection,
)
from vcztools.samples import parse_samples
from vcztools.utils import _as_fixed_length_unicode


# NOTE: this class is just a skeleton for now. The idea is that this
Expand Down Expand Up @@ -86,7 +87,7 @@ def variant_chunk_index_iter(root, regions=None, targets=None):
yield v_chunk, v_mask_chunk

else:
contigs_u = root["contig_id"][:].astype("U").tolist()
contigs_u = _as_fixed_length_unicode(root["contig_id"][:]).tolist()
regions_pyranges = parse_regions(regions, contigs_u)
targets_pyranges, complement = parse_targets(targets, contigs_u)

Expand Down
7 changes: 5 additions & 2 deletions vcztools/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from vcztools.utils import search
from vcztools.utils import _as_fixed_length_unicode, search

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -30,7 +30,7 @@ def parse_samples(
sample_ids = np.array(samples.split(","))

if np.all(sample_ids == np.array("")):
sample_ids = np.empty((0,))
sample_ids = np.empty((0,), dtype=np.dtypes.StringDType())

unknown_samples = np.setdiff1d(sample_ids, all_samples)
if len(unknown_samples) > 0:
Expand All @@ -48,6 +48,9 @@ def parse_samples(
'Use "--force-samples" to ignore this error.'
)

all_samples = _as_fixed_length_unicode(all_samples)
sample_ids = _as_fixed_length_unicode(sample_ids)

samples_selection = search(all_samples, sample_ids)
if exclude_samples:
samples_selection = np.setdiff1d(np.arange(all_samples.size), samples_selection)
Expand Down
4 changes: 2 additions & 2 deletions vcztools/stats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import zarr

from vcztools.utils import open_file_like
from vcztools.utils import _as_fixed_length_unicode, open_file_like


def nrecords(vcz, output):
Expand All @@ -22,7 +22,7 @@ def stats(vcz, output):
)

with open_file_like(output) as output:
contigs = root["contig_id"][:].astype("U").tolist()
contigs = _as_fixed_length_unicode(root["contig_id"][:]).tolist()
if "contig_length" in root:
contig_lengths = root["contig_length"][:]
else:
Expand Down
20 changes: 20 additions & 0 deletions vcztools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,23 @@ def vcf_name_to_vcz_names(vcz_names: set[str], vcf_name: str) -> list[str]:
matches.append(RESERVED_VCF_FIELDS[vcf_name])

return matches


# See https://numpy.org/devdocs/user/basics.strings.html#casting-to-and-from-fixed-width-strings


def _max_len(arr: np.ndarray) -> int:
lengths = np.strings.str_len(arr) # numpy 2
max_len = int(np.max(lengths)) if lengths.size > 0 else 1
return max(max_len, 1)


def _as_fixed_length_string(arr: np.ndarray) -> np.ndarray:
# convert from StringDType to a fixed-length null-terminated byte sequence
# (character code S)
return arr.astype(f"S{_max_len(arr)}")


def _as_fixed_length_unicode(arr: np.ndarray) -> np.ndarray:
# convert from StringDType to a fixed-length unicode string (character code U)
return arr.astype(f"U{_max_len(arr)}")
27 changes: 15 additions & 12 deletions vcztools/vcf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from vcztools.samples import parse_samples
from vcztools.utils import (
_as_fixed_length_string,
_as_fixed_length_unicode,
open_file_like,
)

Expand Down Expand Up @@ -68,7 +70,8 @@


def dims(arr):
return arr.attrs["_ARRAY_DIMENSIONS"]
# Zarr format v2 has _ARRAY_DIMENSIONS, v3 has dedicated metadata
return arr.attrs.get("_ARRAY_DIMENSIONS", None) or arr.metadata.dimension_names


def write_vcf(
Expand Down Expand Up @@ -119,7 +122,7 @@ def write_vcf(
if header_only:
return

contigs = root["contig_id"][:].astype("S")
contigs = _as_fixed_length_string(root["contig_id"][:])
filters = get_filter_ids(root)

for chunk_data in retrieval.variant_chunk_iter(
Expand Down Expand Up @@ -166,7 +169,7 @@ def c_chunk_to_vcf(

# Optional fields which we fill in with "all missing" defaults
if "variant_id" in chunk_data:
id = chunk_data["variant_id"].astype("S")
id = _as_fixed_length_string(chunk_data["variant_id"])
else:
id = np.array(["."] * num_variants, dtype="S")
if "variant_quality" in chunk_data:
Expand Down Expand Up @@ -211,8 +214,8 @@ def c_chunk_to_vcf(
vcf_name = name[len("variant_") :]
info_fields[vcf_name] = array

ref = alleles[:, 0].astype("S")
alt = alleles[:, 1:].astype("S")
ref = _as_fixed_length_string(alleles[:, 0])
alt = _as_fixed_length_string(alleles[:, 1:])

if len(id.shape) == 1:
id = id.reshape((-1, 1))
Expand Down Expand Up @@ -246,16 +249,16 @@ def c_chunk_to_vcf(
encoder.add_gt_field(gt, gt_phased)
for name, zarray in info_fields.items():
# print(array.dtype.kind)
if zarray.dtype.kind in ("O", "U"):
zarray = zarray.astype("S")
if zarray.dtype.kind in ("O", "U", "T"):
zarray = _as_fixed_length_string(zarray)
if len(zarray.shape) == 1:
zarray = zarray.reshape((num_variants, 1))
encoder.add_info_field(name, zarray)

if num_samples != 0:
for name, zarray in format_fields.items():
if zarray.dtype.kind in ("O", "U"):
zarray = zarray.astype("S")
if zarray.dtype.kind in ("O", "U", "T"):
zarray = _as_fixed_length_string(zarray)
if len(zarray.shape) == 2:
zarray = zarray.reshape((num_variants, num_samples, 1))
encoder.add_format_field(name, zarray)
Expand All @@ -281,7 +284,7 @@ def get_filter_ids(root):
does not exist, return a single filter "PASS" by default.
"""
if "filter_id" in root:
filters = root["filter_id"][:].astype("S")
filters = _as_fixed_length_string(root["filter_id"][:])
else:
filters = np.array(["PASS"], dtype="S")
return filters
Expand All @@ -297,7 +300,7 @@ def _generate_header(
output = io.StringIO()

contigs = list(ds["contig_id"][:])
filters = list(get_filter_ids(ds).astype("U"))
filters = list(_as_fixed_length_unicode(get_filter_ids(ds)))
info_fields = []
format_fields = []

Expand Down Expand Up @@ -470,7 +473,7 @@ def _array_to_vcf_type(a):
return "Float"
elif a.dtype.str[1:] in ("S1", "U1"):
return "Character"
elif a.dtype.kind in ("O", "S", "U"):
elif a.dtype.kind in ("O", "S", "U", "T"):
return "String"
else:
raise ValueError(f"Unsupported dtype: {a.dtype}")
Expand Down
Loading