@@ -1233,8 +1233,16 @@ def test_from_standard_tree_sequence(self):
1233
1233
assert tsutil .json_metadata_is_subset (i1 .metadata , i2 .metadata )
1234
1234
# Unless inference is perfect, internal nodes may differ, but sample nodes
1235
1235
# should be identical
1236
- for n1 , n2 in zip (ts .samples (), ts_inferred .samples ()):
1237
- assert ts .node (n1 ) == ts_inferred .node (n2 )
1236
# Compare every sample node of the original tree sequence with the sample
# node at the same position in the inferred tree sequence.
for u1, u2 in zip(ts.samples(), ts_inferred.samples()):
    # NB - flags might differ if e.g. the node is a historical sample
    # but original ones should be maintained
    n1 = ts.node(u1)
    # u2 is a node ID taken from ts_inferred.samples(), so it must be looked
    # up in ts_inferred; looking it up in ts would compare ts with itself and
    # leave the inferred sample nodes unchecked.
    n2 = ts_inferred.node(u2)
    assert (n1.flags & n2.flags) == n1.flags  # n1.flags is subset of n2.flags
    assert n1.time == n2.time
    assert n1.population == n2.population
    assert n1.individual == n2.individual
    # The inferred node may carry extra metadata, but must keep the original's.
    assert tsutil.json_metadata_is_subset(n1.metadata, n2.metadata)
1238
1246
# Sites can have metadata added by the inference process, but inferred site
1239
1247
# metadata should always include all the metadata in the original ts
1240
1248
for s1 , s2 in zip (ts .sites (), ts_inferred .sites ()):
@@ -1723,7 +1731,7 @@ def verify(self, sample_data, mismatch_ratio=None, recombination_rate=None):
1723
1731
ancestors_time = ancestor_data .ancestors_time [:]
1724
1732
num_ancestor_nodes = 0
1725
1733
for n in ancestors_ts .nodes ():
1726
- md = json . loads ( n .metadata ) if n . metadata else {}
1734
+ md = n .metadata
1727
1735
if tsinfer .is_pc_ancestor (n .flags ):
1728
1736
assert not ("ancestor_data_id" in md )
1729
1737
else :
@@ -2774,7 +2782,6 @@ def verify(self, ts):
2774
2782
last_node = ts1 .node (ts1 .num_nodes - 1 )
2775
2783
assert np .max (ts1 .tables .nodes .time ) == last_node .time
2776
2784
md = last_node .metadata
2777
- md = json .loads (md .decode ()) # At the moment node metadata has no schema
2778
2785
assert md .get ("ancestor_data_id" , None ) != 0
2779
2786
2780
2787
# When not post processing and there is no path compression,
@@ -2785,7 +2792,6 @@ def verify(self, ts):
2785
2792
first_node = ts2 .node (0 )
2786
2793
assert np .max (ts2 .tables .nodes .time ) == first_node .time
2787
2794
md = first_node .metadata
2788
- md = json .loads (md .decode ()) # At the moment node metadata has no schema
2789
2795
assert md ["ancestor_data_id" ] == 0
2790
2796
2791
2797
@pytest .mark .parametrize ("simp" , [True , False ])
@@ -2823,7 +2829,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
2823
2829
oldest_parent_id = ts_unsimplified .edge (- 1 ).parent
2824
2830
assert oldest_parent_id == 0
2825
2831
md = ts_unsimplified .node (oldest_parent_id ).metadata
2826
- md = json .loads (md .decode ()) # At the moment node metadata has no schema
2827
2832
assert md ["ancestor_data_id" ] == 0
2828
2833
2829
2834
# Post processing removes ancestor_data_id 0
@@ -2832,7 +2837,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
2832
2837
oldest_parent_id = ts .edge (- 1 ).parent
2833
2838
assert np .sum (ts .tables .nodes .time == ts .node (oldest_parent_id ).time ) == 1
2834
2839
md = ts .node (oldest_parent_id ).metadata
2835
- md = json .loads (md .decode ()) # At the moment node metadata has no schema
2836
2840
assert md ["ancestor_data_id" ] == 1
2837
2841
2838
2842
ts = tsinfer .post_process (
@@ -2844,7 +2848,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
2844
2848
for tree in ts .trees ():
2845
2849
roots .add (tree .root )
2846
2850
md = ts .node (tree .root ).metadata
2847
- md = json .loads (md .decode ()) # At the moment node metadata has no schema
2848
2851
assert md ["ancestor_data_id" ] == 1
2849
2852
assert len (roots ) > 1
2850
2853
@@ -3633,16 +3636,15 @@ def verify_augmented_ancestors(
3633
3636
node = t2 .nodes [m + j ]
3634
3637
assert node .flags == tsinfer .NODE_IS_SAMPLE_ANCESTOR
3635
3638
assert node .time == 1
3636
- metadata = json .loads (node .metadata .decode ())
3637
- assert node_id == metadata ["sample_data_id" ]
3639
+ assert node_id == node .metadata ["sample_data_id" ]
3638
3640
3639
3641
t2 .nodes .truncate (len (t1 .nodes ))
3640
3642
# Adding and subtracting 1 can lead to small diffs, so we compare
3641
3643
# the time separately.
3642
3644
t2 .nodes .time -= 1.0
3643
3645
assert np .allclose (t2 .nodes .time , t1 .nodes .time )
3644
3646
t2 .nodes .time = t1 .nodes .time
3645
- assert t1 .nodes == t2 .nodes
3647
+ t1 .nodes . assert_equals ( t2 .nodes , ignore_metadata = True )
3646
3648
if not path_compression :
3647
3649
# If we have path compression it's possible that some older edges
3648
3650
# will be compressed out.
@@ -3784,8 +3786,7 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
3784
3786
num_sample_ancestors = 0
3785
3787
for node in final_ts .nodes ():
3786
3788
if node .flags == tsinfer .NODE_IS_SAMPLE_ANCESTOR :
3787
- metadata = json .loads (node .metadata .decode ())
3788
- assert metadata ["sample_data_id" ] in subset
3789
+ assert node .metadata ["sample_data_id" ] in subset
3789
3790
num_sample_ancestors += 1
3790
3791
assert expected_sample_ancestors == num_sample_ancestors
3791
3792
tsinfer .verify (samples , final_ts .simplify ())
0 commit comments