diff --git a/vcztools/plink.py b/vcztools/plink.py index 7951661..117566f 100644 --- a/vcztools/plink.py +++ b/vcztools/plink.py @@ -8,6 +8,8 @@ import pandas as pd import zarr +from vcztools.utils import _as_fixed_length_unicode + from . import _vcztools, retrieval @@ -24,7 +26,7 @@ def encode_genotypes(genotypes, a12_allele=None): def generate_fam(root): # TODO generate an error if sample_id contains a space - sample_id = root["sample_id"][:].astype(str) + sample_id = _as_fixed_length_unicode(root["sample_id"][:]) zeros = np.zeros(sample_id.shape, dtype=int) df = pd.DataFrame( { @@ -41,8 +43,8 @@ def generate_fam(root): def generate_bim(root, a12_allele): select = a12_allele[:, 1] != -1 - contig_id = root["contig_id"][:].astype(str) - alleles = root["variant_allele"][:].astype(str)[select] + contig_id = _as_fixed_length_unicode(root["contig_id"][:]) + alleles = _as_fixed_length_unicode(root["variant_allele"][:])[select] a12_allele = a12_allele[select] num_variants = np.sum(select) allele_1 = alleles[np.arange(num_variants), a12_allele[:, 0]] diff --git a/vcztools/retrieval.py b/vcztools/retrieval.py index 4fde829..620da10 100644 --- a/vcztools/retrieval.py +++ b/vcztools/retrieval.py @@ -11,6 +11,7 @@ regions_to_selection, ) from vcztools.samples import parse_samples +from vcztools.utils import _as_fixed_length_unicode # NOTE: this class is just a skeleton for now. 
The idea is that this @@ -86,7 +87,7 @@ def variant_chunk_index_iter(root, regions=None, targets=None): yield v_chunk, v_mask_chunk else: - contigs_u = root["contig_id"][:].astype("U").tolist() + contigs_u = _as_fixed_length_unicode(root["contig_id"][:]).tolist() regions_pyranges = parse_regions(regions, contigs_u) targets_pyranges, complement = parse_targets(targets, contigs_u) diff --git a/vcztools/samples.py b/vcztools/samples.py index deef1e3..71e7edf 100644 --- a/vcztools/samples.py +++ b/vcztools/samples.py @@ -2,7 +2,7 @@ import numpy as np -from vcztools.utils import search +from vcztools.utils import _as_fixed_length_unicode, search logger = logging.getLogger(__name__) @@ -30,7 +30,7 @@ def parse_samples( sample_ids = np.array(samples.split(",")) if np.all(sample_ids == np.array("")): - sample_ids = np.empty((0,)) + sample_ids = np.empty((0,), dtype=np.dtypes.StringDType()) unknown_samples = np.setdiff1d(sample_ids, all_samples) if len(unknown_samples) > 0: @@ -48,6 +48,9 @@ def parse_samples( 'Use "--force-samples" to ignore this error.' 
) + all_samples = _as_fixed_length_unicode(all_samples) + sample_ids = _as_fixed_length_unicode(sample_ids) + samples_selection = search(all_samples, sample_ids) if exclude_samples: samples_selection = np.setdiff1d(np.arange(all_samples.size), samples_selection) diff --git a/vcztools/stats.py b/vcztools/stats.py index 0aac357..2b78f7f 100644 --- a/vcztools/stats.py +++ b/vcztools/stats.py @@ -1,7 +1,7 @@ import numpy as np import zarr -from vcztools.utils import open_file_like +from vcztools.utils import _as_fixed_length_unicode, open_file_like def nrecords(vcz, output): @@ -22,7 +22,7 @@ def stats(vcz, output): ) with open_file_like(output) as output: - contigs = root["contig_id"][:].astype("U").tolist() + contigs = _as_fixed_length_unicode(root["contig_id"][:]).tolist() if "contig_length" in root: contig_lengths = root["contig_length"][:] else: diff --git a/vcztools/utils.py b/vcztools/utils.py index 2c7fe19..57f6590 100644 --- a/vcztools/utils.py +++ b/vcztools/utils.py @@ -59,3 +59,23 @@ def vcf_name_to_vcz_names(vcz_names: set[str], vcf_name: str) -> list[str]: matches.append(RESERVED_VCF_FIELDS[vcf_name]) return matches + + +# See https://numpy.org/devdocs/user/basics.strings.html#casting-to-and-from-fixed-width-strings + + +def _max_len(arr: np.ndarray) -> int: + lengths = np.strings.str_len(arr) # numpy 2 + max_len = int(np.max(lengths)) if lengths.size > 0 else 1 + return max(max_len, 1) + + +def _as_fixed_length_string(arr: np.ndarray) -> np.ndarray: + # convert from StringDType to a fixed-length null-terminated byte sequence + # (character code S) + return arr.astype(f"S{_max_len(arr)}") + + +def _as_fixed_length_unicode(arr: np.ndarray) -> np.ndarray: + # convert from StringDType to a fixed-length unicode string (character code U) + return arr.astype(f"U{_max_len(arr)}") diff --git a/vcztools/vcf_writer.py b/vcztools/vcf_writer.py index 8b8b0cd..873ea99 100644 --- a/vcztools/vcf_writer.py +++ b/vcztools/vcf_writer.py @@ -8,6 +8,8 @@ from 
def dims(arr):
    """Return the dimension names of a Zarr array.

    Zarr format v2 has the ``_ARRAY_DIMENSIONS`` attribute, v3 has dedicated
    metadata. The attribute is preferred whenever it is present, even when it
    is an empty list (a zero-dimensional array); the metadata is only
    consulted when the attribute is absent.

    Parameters
    ----------
    arr
        Array-like object exposing ``attrs`` (a mapping) and ``metadata``.

    Returns
    -------
    The value of ``attrs["_ARRAY_DIMENSIONS"]`` if set, otherwise
    ``arr.metadata.dimension_names``.
    """
    array_dims = arr.attrs.get("_ARRAY_DIMENSIONS")
    # Test against None rather than truthiness: an empty dimension list is a
    # valid answer and must not fall through to the metadata lookup.
    if array_dims is not None:
        return array_dims
    return arr.metadata.dimension_names
zarray) @@ -281,7 +284,7 @@ def get_filter_ids(root): does not exist, return a single filter "PASS" by default. """ if "filter_id" in root: - filters = root["filter_id"][:].astype("S") + filters = _as_fixed_length_string(root["filter_id"][:]) else: filters = np.array(["PASS"], dtype="S") return filters @@ -297,7 +300,7 @@ def _generate_header( output = io.StringIO() contigs = list(ds["contig_id"][:]) - filters = list(get_filter_ids(ds).astype("U")) + filters = list(_as_fixed_length_unicode(get_filter_ids(ds))) info_fields = [] format_fields = [] @@ -470,7 +473,7 @@ def _array_to_vcf_type(a): return "Float" elif a.dtype.str[1:] in ("S1", "U1"): return "Character" - elif a.dtype.kind in ("O", "S", "U"): + elif a.dtype.kind in ("O", "S", "U", "T"): return "String" else: raise ValueError(f"Unsupported dtype: {a.dtype}")