|
| 1 | +import logging |
| 2 | +import pathlib |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import tskit |
| 6 | + |
| 7 | +from bio2zarr import constants, core, vcz |
| 8 | + |
| 9 | +logger = logging.getLogger(__name__) |
| 10 | + |
| 11 | + |
class TskitFormat:
    """Expose a tskit tree sequence as a source for ``vcz.VcfZarrWriter``.

    The tree sequence is loaded from ``ts_path``; every site becomes one
    variant record on a single contig, and sample nodes are grouped into
    samples (see ``_make_sample_mapping``). ``convert`` passes an instance
    of this class to ``VcfZarrWriter.init``.
    """

    def __init__(self, ts_path, contig_id=None, ploidy=None, isolated_as_missing=False):
        """Load the tree sequence and derive the sample/contig metadata.

        Args:
            ts_path: Path handed to ``tskit.load``.
            contig_id: ID of the single contig; defaults to ``"1"``.
            ploidy: Number of consecutive sample nodes grouped into each
                sample when no sample node references an individual.
                Must be ``None`` if the tables contain individuals.
            isolated_as_missing: Forwarded to ``ts.variants`` whenever
                genotypes or alleles are decoded.

        Raises:
            ValueError: If the sample/individual/ploidy configuration is
                inconsistent (see ``_make_sample_mapping``).
        """
        self.path = ts_path
        self.ts = tskit.load(ts_path)
        self.contig_id = contig_id if contig_id is not None else "1"
        self.isolated_as_missing = isolated_as_missing
        self.root_attrs = {}

        self._make_sample_mapping(ploidy)
        self.contigs = [vcz.Contig(id=self.contig_id)]
        self.num_records = self.ts.num_sites
        self.positions = self.ts.sites_position

    def _make_sample_mapping(self, ploidy):
        """Group sample nodes into samples and record each sample's ploidy.

        Sets ``sample_ids`` (node IDs in output order), ``individual_ploidies``
        (nodes per sample), ``max_ploidy``, ``num_samples`` and ``samples``.

        Raises:
            ValueError: If ``ploidy`` is given while individuals exist in the
                tables, if only some sample nodes reference an individual, if
                an individual mixes sample and non-sample nodes, or if the
                requested ploidy is < 1 or does not divide the sample count.
        """
        ts = self.ts
        self.individual_ploidies = []
        self.max_ploidy = 0

        if ts.num_individuals > 0 and ploidy is not None:
            raise ValueError(
                "Cannot specify ploidy when individuals are present in tables"
            )

        # Find all sample nodes that reference individuals
        individuals = np.unique(ts.tables.nodes.individual[ts.samples()])
        # np.unique sorts the argument, so if NULL (-1) is present it
        # will be the first value.
        has_null = len(individuals) > 0 and individuals[0] == tskit.NULL
        if has_null and len(individuals) > 1:
            raise ValueError(
                "Sample nodes must either all be associated with individuals "
                "or not associated with any individuals"
            )
        if has_null or len(individuals) == 0:
            # No samples refer to individuals. This also covers a tree
            # sequence with no samples at all, where indexing the empty
            # array would otherwise raise IndexError.
            individuals = None

        if individuals is not None:
            self.sample_ids = []
            for i in individuals:
                if i < 0 or i >= ts.num_individuals:
                    raise ValueError("Invalid individual IDs provided.")
                ind = ts.individual(i)
                if len(ind.nodes) == 0:
                    raise ValueError(f"Individual {i} not associated with a node")
                is_sample = {ts.node(u).is_sample() for u in ind.nodes}
                if len(is_sample) != 1:
                    raise ValueError(
                        f"Individual {ind.id} has nodes that are sample and "
                        "non-samples"
                    )
                self.sample_ids.extend(ind.nodes)
                self.individual_ploidies.append(len(ind.nodes))
                self.max_ploidy = max(self.max_ploidy, len(ind.nodes))
        else:
            if ploidy is None:
                ploidy = 1
            if ploidy < 1:
                raise ValueError("Ploidy must be >= 1")
            if ts.num_samples % ploidy != 0:
                raise ValueError("Sample size must be divisible by ploidy")
            self.individual_ploidies = np.full(
                ts.num_samples // ploidy, ploidy, dtype=np.int32
            )
            self.max_ploidy = ploidy
            self.sample_ids = np.arange(ts.num_samples, dtype=np.int32)

        self.num_samples = len(self.individual_ploidies)

        self.samples = [vcz.Sample(id=f"tsk_{j}") for j in range(self.num_samples)]

    def _iter_variants(self, start, stop):
        """Yield tskit variants for the sites with index in ``[start, stop)``.

        An empty range yields nothing; this avoids indexing
        ``self.positions`` when there are no sites to decode.
        """
        if start >= stop:
            return
        yield from self.ts.variants(
            samples=self.sample_ids,
            isolated_as_missing=self.isolated_as_missing,
            left=self.positions[start],
            # ``right`` is exclusive; None means "to the end of the sequence".
            right=self.positions[stop] if stop < self.num_records else None,
        )

    def iter_alleles(self, start, stop, num_alleles):
        """Yield one object array of length ``num_alleles`` per variant.

        Unused trailing slots hold ``constants.STR_FILL``.
        """
        for variant in self._iter_variants(start, stop):
            alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
            for i, allele in enumerate(variant.alleles):
                if allele is None:
                    # tskit appends None to mark missing data; it is not an
                    # allele string and must not be written to the array.
                    continue
                assert i < num_alleles
                alleles[i] = allele
            yield alleles

    def iter_contig(self, start, stop):
        """Yield the contig index (always 0) for each variant in the range."""
        yield from (0 for _ in range(start, stop))

    def iter_field(self, field_name, shape, start, stop):
        """Yield per-variant values of ``field_name`` for ``[start, stop)``.

        Only ``"position"`` is supported; positions are yielded as ``int``.

        Raises:
            ValueError: For any other field name.
        """
        if field_name == "position":
            # Use the cached positions array rather than rebuilding
            # ``ts.tables`` on every call.
            for pos in self.positions[start:stop]:
                yield int(pos)
        else:
            raise ValueError(f"Unknown field {field_name}")

    def iter_genotypes(self, shape, start, stop):
        """Yield ``(gt, phased)`` arrays for each variant in ``[start, stop)``.

        ``gt`` has the given ``shape`` (samples x ploidy) and dtype int8.
        Slots beyond a sample's own ploidy hold the fill value -2. ``phased``
        has shape ``shape[:-1]`` and is True for every sample. Fresh arrays
        are produced for every variant.
        """
        ploidies = np.asarray(self.individual_ploidies, dtype=np.int64)
        # used[i, j] is True when sample i has a j-th genome copy. Boolean
        # assignment fills these slots in row-major order, which is the
        # order of ``sample_ids`` (each sample's nodes are consecutive).
        used = np.arange(shape[-1]) < ploidies[:, np.newaxis]

        for variant in self._iter_variants(start, stop):
            gt = np.full(shape, -2, dtype=np.int8)  # -2 is the fill value
            gt[used] = variant.genotypes
            # In tskit, all genotypes are considered phased
            phased = np.zeros(shape[:-1], dtype=bool)
            phased[: len(ploidies)] = True
            yield gt, phased

    def generate_schema(
        self,
        variants_chunk_size=None,
        samples_chunk_size=None,
    ):
        """Build the ``vcz.VcfZarrSchema`` describing the arrays to write.

        Scans every variant once to find the largest number of alleles, then
        declares the position, allele, contig and genotype arrays.

        Args:
            variants_chunk_size: Passed to ``vcz.VcfZarrSchema``.
            samples_chunk_size: Passed to ``vcz.VcfZarrSchema``.

        Returns:
            The populated ``vcz.VcfZarrSchema`` instance.
        """
        n = self.num_samples
        m = self.ts.num_sites

        # Determine max number of alleles, decoding exactly as the iterators
        # do and ignoring the None that tskit appends for missing data.
        max_alleles = 0
        for variant in self.ts.variants(
            samples=self.sample_ids,
            isolated_as_missing=self.isolated_as_missing,
        ):
            num_alleles = sum(allele is not None for allele in variant.alleles)
            max_alleles = max(max_alleles, num_alleles)

        logger.info(f"Scanned tskit with {n} samples and {m} variants")
        logger.info(
            f"Maximum ploidy: {self.max_ploidy}, maximum alleles: {max_alleles}"
        )

        schema_instance = vcz.VcfZarrSchema(
            format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
            samples_chunk_size=samples_chunk_size,
            variants_chunk_size=variants_chunk_size,
            fields=[],
        )

        logger.info(
            "Generating schema with chunks="
            f"{schema_instance.variants_chunk_size, schema_instance.samples_chunk_size}"
        )

        array_specs = [
            vcz.ZarrArraySpec.new(
                vcf_field="position",
                name="variant_position",
                dtype="i4",
                shape=[m],
                dimensions=["variants"],
                chunks=[schema_instance.variants_chunk_size],
                description="Position of each variant",
            ),
            vcz.ZarrArraySpec.new(
                vcf_field=None,
                name="variant_allele",
                dtype="O",
                shape=[m, max_alleles],
                dimensions=["variants", "alleles"],
                chunks=[schema_instance.variants_chunk_size, max_alleles],
                description="Alleles for each variant",
            ),
            vcz.ZarrArraySpec.new(
                vcf_field=None,
                name="variant_contig",
                dtype=core.min_int_dtype(0, len(self.contigs)),
                shape=[m],
                dimensions=["variants"],
                chunks=[schema_instance.variants_chunk_size],
                description="Contig/chromosome index for each variant",
            ),
            vcz.ZarrArraySpec.new(
                vcf_field=None,
                name="call_genotype_phased",
                dtype="bool",
                shape=[m, n],
                dimensions=["variants", "samples"],
                chunks=[
                    schema_instance.variants_chunk_size,
                    schema_instance.samples_chunk_size,
                ],
                description="Whether the genotype is phased",
            ),
            vcz.ZarrArraySpec.new(
                vcf_field=None,
                name="call_genotype",
                dtype="i1",
                shape=[m, n, self.max_ploidy],
                dimensions=["variants", "samples", "ploidy"],
                chunks=[
                    schema_instance.variants_chunk_size,
                    schema_instance.samples_chunk_size,
                    self.max_ploidy,
                ],
                description="Genotype for each variant and sample",
            ),
            vcz.ZarrArraySpec.new(
                vcf_field=None,
                name="call_genotype_mask",
                dtype="bool",
                shape=[m, n, self.max_ploidy],
                dimensions=["variants", "samples", "ploidy"],
                chunks=[
                    schema_instance.variants_chunk_size,
                    schema_instance.samples_chunk_size,
                    self.max_ploidy,
                ],
                description="Mask for each genotype call",
            ),
        ]
        schema_instance.fields = array_specs
        return schema_instance
| 233 | + |
| 234 | + |
def convert(
    ts_path,
    zarr_path,
    *,
    contig_id=None,
    ploidy=None,
    isolated_as_missing=False,
    variants_chunk_size=None,
    samples_chunk_size=None,
    worker_processes=1,
    show_progress=False,
):
    """Convert the tree sequence at ``ts_path`` to a store at ``zarr_path``.

    Builds a ``TskitFormat`` source and its schema, then drives a
    ``vcz.VcfZarrWriter`` through init, partition encoding, finalisation
    and index creation.

    Args:
        ts_path: Tree sequence file to read.
        zarr_path: Output location; wrapped in ``pathlib.Path``.
        contig_id: Forwarded to ``TskitFormat``.
        ploidy: Forwarded to ``TskitFormat``.
        isolated_as_missing: Forwarded to ``TskitFormat``.
        variants_chunk_size: Forwarded to ``TskitFormat.generate_schema``.
        samples_chunk_size: Forwarded to ``TskitFormat.generate_schema``.
        worker_processes: Forwarded to the writer's partition encoding; also
            determines the target number of partitions.
        show_progress: Forwarded to the writer's encode and finalise steps.
    """
    source = TskitFormat(
        ts_path,
        contig_id=contig_id,
        ploidy=ploidy,
        isolated_as_missing=isolated_as_missing,
    )
    schema = source.generate_schema(
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
    )

    writer = vcz.VcfZarrWriter(TskitFormat, pathlib.Path(zarr_path))
    writer.init(
        source,
        # Rough heuristic to split work up enough to keep utilisation high
        target_num_partitions=max(1, worker_processes * 4),
        schema=schema,
    )
    writer.encode_all_partitions(
        worker_processes=worker_processes,
        show_progress=show_progress,
    )
    writer.finalise(show_progress)
    writer.create_index()
0 commit comments