Skip to content

Commit 342454c

Browse files
hyanwong authored and benjeffery committed
Set permissive schema on node metadata
Also, there is no longer any need to fill empty JSON metadata with `{}` in tests/tsutil.py
1 parent d48caf2 commit 342454c

File tree

3 files changed

+50
-37
lines changed

3 files changed

+50
-37
lines changed

tests/test_inference.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,8 +1233,16 @@ def test_from_standard_tree_sequence(self):
12331233
assert tsutil.json_metadata_is_subset(i1.metadata, i2.metadata)
12341234
# Unless inference is perfect, internal nodes may differ, but sample nodes
12351235
# should be identical
1236-
for n1, n2 in zip(ts.samples(), ts_inferred.samples()):
1237-
assert ts.node(n1) == ts_inferred.node(n2)
1236+
for u1, u2 in zip(ts.samples(), ts_inferred.samples()):
1237+
# NB - flags might differ if e.g. the node is a historical sample
1238+
# but original ones should be maintained
1239+
n1 = ts.node(u1)
1240+
n2 = ts.node(u2)
1241+
assert (n1.flags & n2.flags) == n1.flags # n1.flags is subset of n2.flags
1242+
assert n1.time == n2.time
1243+
assert n1.population == n2.population
1244+
assert n1.individual == n2.individual
1245+
assert tsutil.json_metadata_is_subset(n1.metadata, n2.metadata)
12381246
# Sites can have metadata added by the inference process, but inferred site
12391247
# metadata should always include all the metadata in the original ts
12401248
for s1, s2 in zip(ts.sites(), ts_inferred.sites()):
@@ -1723,7 +1731,7 @@ def verify(self, sample_data, mismatch_ratio=None, recombination_rate=None):
17231731
ancestors_time = ancestor_data.ancestors_time[:]
17241732
num_ancestor_nodes = 0
17251733
for n in ancestors_ts.nodes():
1726-
md = json.loads(n.metadata) if n.metadata else {}
1734+
md = n.metadata
17271735
if tsinfer.is_pc_ancestor(n.flags):
17281736
assert not ("ancestor_data_id" in md)
17291737
else:
@@ -2774,7 +2782,6 @@ def verify(self, ts):
27742782
last_node = ts1.node(ts1.num_nodes - 1)
27752783
assert np.max(ts1.tables.nodes.time) == last_node.time
27762784
md = last_node.metadata
2777-
md = json.loads(md.decode()) # At the moment node metadata has no schema
27782785
assert md.get("ancestor_data_id", None) != 0
27792786

27802787
# When not post processing and there is no path compression,
@@ -2785,7 +2792,6 @@ def verify(self, ts):
27852792
first_node = ts2.node(0)
27862793
assert np.max(ts2.tables.nodes.time) == first_node.time
27872794
md = first_node.metadata
2788-
md = json.loads(md.decode()) # At the moment node metadata has no schema
27892795
assert md["ancestor_data_id"] == 0
27902796

27912797
@pytest.mark.parametrize("simp", [True, False])
@@ -2823,7 +2829,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
28232829
oldest_parent_id = ts_unsimplified.edge(-1).parent
28242830
assert oldest_parent_id == 0
28252831
md = ts_unsimplified.node(oldest_parent_id).metadata
2826-
md = json.loads(md.decode()) # At the moment node metadata has no schema
28272832
assert md["ancestor_data_id"] == 0
28282833

28292834
# Post processing removes ancestor_data_id 0
@@ -2832,7 +2837,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
28322837
oldest_parent_id = ts.edge(-1).parent
28332838
assert np.sum(ts.tables.nodes.time == ts.node(oldest_parent_id).time) == 1
28342839
md = ts.node(oldest_parent_id).metadata
2835-
md = json.loads(md.decode()) # At the moment node metadata has no schema
28362840
assert md["ancestor_data_id"] == 1
28372841

28382842
ts = tsinfer.post_process(
@@ -2844,7 +2848,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
28442848
for tree in ts.trees():
28452849
roots.add(tree.root)
28462850
md = ts.node(tree.root).metadata
2847-
md = json.loads(md.decode()) # At the moment node metadata has no schema
28482851
assert md["ancestor_data_id"] == 1
28492852
assert len(roots) > 1
28502853

@@ -3633,16 +3636,15 @@ def verify_augmented_ancestors(
36333636
node = t2.nodes[m + j]
36343637
assert node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR
36353638
assert node.time == 1
3636-
metadata = json.loads(node.metadata.decode())
3637-
assert node_id == metadata["sample_data_id"]
3639+
assert node_id == node.metadata["sample_data_id"]
36383640

36393641
t2.nodes.truncate(len(t1.nodes))
36403642
# Adding and subtracting 1 can lead to small diffs, so we compare
36413643
# the time separately.
36423644
t2.nodes.time -= 1.0
36433645
assert np.allclose(t2.nodes.time, t1.nodes.time)
36443646
t2.nodes.time = t1.nodes.time
3645-
assert t1.nodes == t2.nodes
3647+
t1.nodes.assert_equals(t2.nodes, ignore_metadata=True)
36463648
if not path_compression:
36473649
# If we have path compression it's possible that some older edges
36483650
# will be compressed out.
@@ -3784,8 +3786,7 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
37843786
num_sample_ancestors = 0
37853787
for node in final_ts.nodes():
37863788
if node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR:
3787-
metadata = json.loads(node.metadata.decode())
3788-
assert metadata["sample_data_id"] in subset
3789+
assert node.metadata["sample_data_id"] in subset
37893790
num_sample_ancestors += 1
37903791
assert expected_sample_ancestors == num_sample_ancestors
37913792
tsinfer.verify(samples, final_ts.simplify())

tests/tsutil.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ def add_default_schemas(ts):
5555
tables.populations[pop.id] = pop
5656
tables.individuals.metadata_schema = schema
5757
assert len(tables.individuals.metadata) == 0
58-
tables.individuals.packset_metadata([b"{}"] * ts.num_individuals)
5958
tables.sites.metadata_schema = schema
6059
assert len(tables.sites.metadata) == 0
61-
tables.sites.packset_metadata([b"{}"] * ts.num_sites)
60+
tables.nodes.metadata_schema = schema
61+
assert len(tables.nodes.metadata) == 0
6262
return tables.tree_sequence()
6363

6464

tsinfer/inference.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,21 @@
7070
],
7171
}
7272

73+
node_ancestor_data_id_metadata_definition = {
74+
"description": (
75+
"The ID of the tsinfer ancestor data node from which this node is derived."
76+
),
77+
"type": "number",
78+
}
79+
80+
node_sample_data_id_metadata_definition = {
81+
"description": (
82+
"The ID of the tsinfer sample data node from which this node is derived. "
83+
"Only present for nodes in which historical samples are treated as ancestors."
84+
),
85+
"type": "number",
86+
}
87+
7388

7489
def add_to_schema(schema, name, definition=None, required=False):
7590
"""
@@ -1550,7 +1565,7 @@ def match_ancestors(self):
15501565
logger.info("Finished ancestor matching")
15511566
return ts
15521567

1553-
def get_ancestors_tables(self):
1568+
def fill_ancestors_tables(self, tables):
15541569
"""
15551570
Return the ancestors tree sequence tables. Only inference sites are included in
15561571
this tree sequence. All nodes have the sample flag bit set, and if a node
@@ -1559,21 +1574,10 @@ def get_ancestors_tables(self):
15591574
logger.debug("Building ancestors tree sequence")
15601575
tsb = self.tree_sequence_builder
15611576

1562-
tables = tskit.TableCollection(
1563-
sequence_length=self.ancestor_data.sequence_length
1564-
)
1565-
15661577
flags, times = tsb.dump_nodes()
15671578
pc_ancestors = is_pc_ancestor(flags)
15681579
tables.nodes.set_columns(flags=flags, time=times)
15691580

1570-
# # FIXME we should do this as a struct codec?
1571-
# dict_schema = permissive_json_schema()
1572-
# dict_schema = add_to_schema(dict_schema, "ancestor_data_id",
1573-
# {"type": "integer"})
1574-
# schema = tskit.MetadataSchema(dict_schema)
1575-
# tables.nodes.schema = schema
1576-
15771581
# Add metadata for any non-PC node, pointing to the original ancestor
15781582
metadata = []
15791583
ancestor = 0
@@ -1611,16 +1615,20 @@ def get_ancestors_tables(self):
16111615
len(tables.sites),
16121616
)
16131617
)
1614-
return tables
16151618

16161619
def store_output(self):
1620+
tables = tskit.TableCollection(
1621+
sequence_length=self.ancestor_data.sequence_length
1622+
)
1623+
# We decided to use a permissive schema for the metadata, for flexibility
1624+
dict_schema = tskit.MetadataSchema.permissive_json().schema
1625+
dict_schema = add_to_schema(
1626+
dict_schema, "ancestor_data_id", node_ancestor_data_id_metadata_definition
1627+
)
1628+
tables.nodes.metadata_schema = tskit.MetadataSchema(dict_schema)
1629+
16171630
if self.num_ancestors > 0:
1618-
tables = self.get_ancestors_tables()
1619-
else:
1620-
# Allocate an empty tree sequence.
1621-
tables = tskit.TableCollection(
1622-
sequence_length=self.ancestor_data.sequence_length
1623-
)
1631+
self.fill_ancestors_tables(tables)
16241632
tables.time_units = self.time_units
16251633
return tables.tree_sequence()
16261634

@@ -2062,6 +2070,12 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
20622070
logger.debug("Building augmented ancestors tree sequence")
20632071
tsb = self.tree_sequence_builder
20642072
tables = self.ancestors_ts_tables.copy()
2073+
dict_schema = tables.nodes.metadata_schema.schema
2074+
assert dict_schema is not None
2075+
dict_schema = add_to_schema(
2076+
dict_schema, "sample_data_id", node_sample_data_id_metadata_definition
2077+
)
2078+
tables.nodes.metadata_schema = tskit.MetadataSchema(dict_schema)
20652079

20662080
flags, times = tsb.dump_nodes()
20672081
s = 0
@@ -2072,9 +2086,7 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
20722086
tables.nodes.add_row(
20732087
flags=constants.NODE_IS_SAMPLE_ANCESTOR,
20742088
time=times[j],
2075-
metadata=_encode_raw_metadata(
2076-
{"sample_data_id": int(sample_indexes[s])}
2077-
),
2089+
metadata={"sample_data_id": int(sample_indexes[s])},
20782090
)
20792091
s += 1
20802092
else:

0 commit comments

Comments
 (0)