9
9
logger = logging .getLogger (__name__ )
10
10
11
11
12
- import numpy as np
13
- import tskit
14
-
15
- def individual_nodes (ts ):
16
- """
17
- Convert a tree sequence with individuals to a 2D array of node IDs.
18
-
19
- Parameters
20
- ----------
21
- ts : tskit.TreeSequence
22
- The tree sequence to convert
23
-
24
- Returns
25
- -------
26
- numpy.ndarray
27
- Array of shape (num_individuals, max_ploidy) containing node IDs.
28
- Values of -1 indicate unused slots for individuals with ploidy
29
- less than the maximum.
30
-
31
- Raises
32
- ------
33
- ValueError
34
- If the tree sequence has no individuals, if any sample doesn't have an individual,
35
- if individuals have nodes that are both samples and non-samples, or if an
36
- individuals has no samples.
37
- """
38
- if ts .num_individuals == 0 :
39
- raise ValueError ("Tree sequence has no individuals" )
40
-
41
- individuals = np .unique (ts .nodes_individual [ts .samples ()])
42
- if len (individuals ) == 1 and individuals [0 ] == tskit .NULL :
43
- raise ValueError ("No samples refer to individuals" )
44
-
45
- # np.unique sorts the argument, so if NULL (-1) is present it will be first
46
- if individuals [0 ] == tskit .NULL :
47
- raise ValueError (
48
- "Sample nodes must all be associated with individuals"
49
- )
50
-
51
- max_ploidy = 0
52
- for i in range (ts .num_individuals ):
53
- ind = ts .individual (i )
54
- max_ploidy = max (max_ploidy , len (ind .nodes ))
55
-
56
- # Initialize output array with -1 (indicating no node)
57
- result = np .full ((ts .num_individuals , max_ploidy ), - 1 , dtype = np .int32 )
58
-
59
- for i in range (ts .num_individuals ):
60
- ind = ts .individual (i )
61
- if len (ind .nodes ) == 0 :
62
- raise ValueError (f"Individual { i } not associated with any nodes" )
63
-
64
- is_sample = {ts .node (u ).is_sample () for u in ind .nodes }
65
- if len (is_sample ) != 1 :
66
- raise ValueError (
67
- f"Individual { ind .id } has nodes that are sample and non-samples"
68
- )
69
-
70
- for j , node_id in enumerate (ind .nodes ):
71
- result [i , j ] = node_id
72
-
73
- return result
74
-
75
12
class TskitFormat (vcz .Source ):
76
- def __init__ (self , ts_path , contig_id = None , isolated_as_missing = False ):
13
+ def __init__ (
14
+ self ,
15
+ ts_path ,
16
+ individual_nodes ,
17
+ sample_ids = None ,
18
+ contig_id = None ,
19
+ isolated_as_missing = False ,
20
+ ):
77
21
self ._path = ts_path
78
22
self .ts = tskit .load (ts_path )
79
23
self .contig_id = contig_id if contig_id is not None else "1"
80
24
self .isolated_as_missing = isolated_as_missing
81
25
82
- self ._make_sample_mapping ()
83
26
self .positions = self .ts .sites_position
84
27
28
+ self ._num_samples = individual_nodes .shape [0 ]
29
+ if self ._num_samples < 1 :
30
+ raise ValueError ("individual_nodes must have at least one sample" )
31
+ self .max_ploidy = individual_nodes .shape [1 ]
32
+ if sample_ids is None :
33
+ sample_ids = [f"tsk_{ j } " for j in range (self ._num_samples )]
34
+ elif len (sample_ids ) != self ._num_samples :
35
+ raise ValueError (
36
+ f"Length of sample_ids ({ len (sample_ids )} ) does not match "
37
+ f"number of samples ({ self ._num_samples } )"
38
+ )
39
+
40
+ self ._samples = [vcz .Sample (id = sample_id ) for sample_id in sample_ids ]
41
+
42
+ self .tskit_samples = np .unique (individual_nodes [individual_nodes >= 0 ])
43
+ if len (self .tskit_samples ) < 1 :
44
+ raise ValueError ("individual_nodes must have at least one valid sample" )
45
+ node_id_to_index = {node_id : i for i , node_id in enumerate (self .tskit_samples )}
46
+ valid_mask = individual_nodes >= 0
47
+ self .sample_indices , self .ploidy_indices = np .where (valid_mask )
48
+ self .genotype_indices = np .array (
49
+ [node_id_to_index [node_id ] for node_id in individual_nodes [valid_mask ]]
50
+ )
51
+
85
52
@property
86
53
def path (self ):
87
54
return self ._path
@@ -106,21 +73,6 @@ def root_attrs(self):
106
73
def contigs (self ):
107
74
return [vcz .Contig (id = self .contig_id )]
108
75
109
- def _make_sample_mapping (self ):
110
- ts = self .ts
111
-
112
- # Use individual_nodes to get the mapping between individuals and nodes
113
- try :
114
- # Get a 2D array of node IDs for each individual
115
- self .node_ids_array = individual_nodes (ts )
116
- self ._num_samples = ts .num_individuals
117
- self .max_ploidy = self .node_ids_array .shape [1 ]
118
-
119
- except ValueError as e :
120
- raise ValueError (f"Error mapping individuals to nodes: { e } " ) from e
121
-
122
- self ._samples = [vcz .Sample (id = f"tsk_{ j } " ) for j in range (self .num_samples )]
123
-
124
76
def iter_contig (self , start , stop ):
125
77
yield from (0 for _ in range (start , stop ))
126
78
@@ -139,6 +91,7 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
139
91
isolated_as_missing = self .isolated_as_missing ,
140
92
left = self .positions [start ],
141
93
right = self .positions [stop ] if stop < self .num_records else None ,
94
+ samples = self .tskit_samples ,
142
95
):
143
96
gt = np .full (shape , constants .INT_FILL , dtype = np .int8 )
144
97
alleles = np .full (num_alleles , constants .STR_FILL , dtype = "O" )
@@ -149,12 +102,9 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
149
102
assert i < num_alleles
150
103
alleles [i ] = allele
151
104
152
- # For each individual, get genotypes for their nodes
153
- for i in range (self .num_samples ):
154
- for j in range (self .max_ploidy ):
155
- node_id = self .node_ids_array [i , j ]
156
- if node_id >= 0 : # Skip -1 entries (unused slots)
157
- gt [i , j ] = variant .genotypes [node_id ]
105
+ gt [self .sample_indices , self .ploidy_indices ] = variant .genotypes [
106
+ self .genotype_indices
107
+ ]
158
108
159
109
yield alleles , (gt , phased )
160
110
@@ -253,7 +203,9 @@ def generate_schema(
253
203
def convert (
254
204
ts_path ,
255
205
zarr_path ,
206
+ individual_nodes ,
256
207
* ,
208
+ sample_ids = None ,
257
209
contig_id = None ,
258
210
isolated_as_missing = False ,
259
211
variants_chunk_size = None ,
@@ -263,6 +215,8 @@ def convert(
263
215
):
264
216
tskit_format = TskitFormat (
265
217
ts_path ,
218
+ individual_nodes ,
219
+ sample_ids = sample_ids ,
266
220
contig_id = contig_id ,
267
221
isolated_as_missing = isolated_as_missing ,
268
222
)
0 commit comments