Skip to content

Commit 3786aea

Browse files
committed
Use struct for node metadata
1 parent 99cad13 commit 3786aea

File tree

3 files changed

+44
-17
lines changed

3 files changed

+44
-17
lines changed

tests/test_inference.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,9 +1094,11 @@ def test_from_standard_tree_sequence(self):
10941094
assert i1.flags == i2.flags
10951095
assert tsutil.json_metadata_is_subset(i1.metadata, i2.metadata)
10961096
# Unless inference is perfect, internal nodes may differ, but sample nodes
1097-
# should be identical
1097+
# should be identical. Node metadata is not transferred, however, and a tsinfer-
1098+
# specific node metadata schema is used (where empty is None rather than b"")
1099+
assert ts.table_metadata_schemas.node == tsinfer.formats.node_metadata_schema()
10981100
for n1, n2 in zip(ts.samples(), ts_inferred.samples()):
1099-
assert ts.node(n1) == ts_inferred.node(n2)
1101+
assert ts.node(n1).replace(metadata=None) == ts_inferred.node(n2)
11001102
# Sites can have metadata added by the inference process, but inferred site
11011103
# metadata should always include all the metadata in the original ts
11021104
for s1, s2 in zip(ts.sites(), ts_inferred.sites()):
@@ -1586,7 +1588,7 @@ def verify(self, sample_data, mismatch_ratio=None, recombination_rate=None):
15861588
ancestors_time = ancestor_data.ancestors_time[:]
15871589
num_ancestor_nodes = 0
15881590
for n in ancestors_ts.nodes():
1589-
md = json.loads(n.metadata) if n.metadata else {}
1591+
md = n.metadata if n.metadata else {}
15901592
if tsinfer.is_pc_ancestor(n.flags):
15911593
assert not ("ancestor_data_id" in md)
15921594
else:
@@ -3114,8 +3116,7 @@ def verify_augmented_ancestors(
31143116
node = t2.nodes[m + j]
31153117
assert node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR
31163118
assert node.time == 1
3117-
metadata = json.loads(node.metadata.decode())
3118-
assert node_id == metadata["sample_data_id"]
3119+
assert node_id == node.metadata["sample_data_id"]
31193120

31203121
t2.nodes.truncate(len(t1.nodes))
31213122
# Adding and subtracting 1 can lead to small diffs, so we compare
@@ -3265,8 +3266,7 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
32653266
num_sample_ancestors = 0
32663267
for node in final_ts.nodes():
32673268
if node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR:
3268-
metadata = json.loads(node.metadata.decode())
3269-
assert metadata["sample_data_id"] in subset
3269+
assert node.metadata["sample_data_id"] in subset
32703270
num_sample_ancestors += 1
32713271
assert expected_sample_ancestors == num_sample_ancestors
32723272
tsinfer.verify(samples, final_ts.simplify())

tsinfer/formats.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,32 @@ def permissive_json_schema():
7474
}
7575

7676

77+
def node_metadata_schema():
78+
# This is fixed by tsinfer: users cannot add to the node metadata
79+
return tskit.MetadataSchema(
80+
{
81+
"codec": "struct",
82+
"type": ["object", "null"],
83+
"properties": {
84+
"ancestor_data_id": {
85+
"description": "",
86+
"type": "integer",
87+
"binaryFormat": "i",
88+
"default": -1,
89+
},
90+
"sample_data_id": {
91+
"description": "Date of sample collection in ISO format",
92+
"type": "integer",
93+
"binaryFormat": "i",
94+
"default": -1,
95+
},
96+
},
97+
"required": ["ancestor_data_id", "sample_data_id"],
98+
"additionalProperties": False,
99+
}
100+
)
101+
102+
77103
def np_obj_equal(np_obj_array1, np_obj_array2):
78104
"""
79105
A replacement for np.array_equal to test equality of numpy arrays that

tsinfer/inference.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,12 +1411,7 @@ def get_ancestors_tree_sequence(self):
14111411
pc_ancestors = is_pc_ancestor(flags)
14121412
tables.nodes.set_columns(flags=flags, time=times)
14131413

1414-
# # FIXME we should do this as a struct codec?
1415-
# dict_schema = permissive_json_schema()
1416-
# dict_schema = add_to_schema(dict_schema, "ancestor_data_id",
1417-
# {"type": "integer"})
1418-
# schema = tskit.MetadataSchema(dict_schema)
1419-
# tables.nodes.schema = schema
1414+
tables.nodes.metadata_schema = formats.node_metadata_schema()
14201415

14211416
# Add metadata for any non-PC node, pointing to the original ancestor
14221417
metadata = []
@@ -1425,7 +1420,11 @@ def get_ancestors_tree_sequence(self):
14251420
if is_pc:
14261421
metadata.append(b"")
14271422
else:
1428-
metadata.append(_encode_raw_metadata({"ancestor_data_id": ancestor}))
1423+
metadata.append(
1424+
tables.nodes.metadata_schema.validate_and_encode_row(
1425+
{"ancestor_data_id": ancestor, "sample_data_id": tskit.NULL}
1426+
)
1427+
)
14291428
ancestor += 1
14301429
tables.nodes.packset_metadata(metadata)
14311430
left, right, parent, child = tsb.dump_edges()
@@ -1471,6 +1470,7 @@ def store_output(self):
14711470
tables = tskit.TableCollection(
14721471
sequence_length=self.ancestor_data.sequence_length
14731472
)
1473+
tables.nodes.metadata_schema = formats.node_metadata_schema()
14741474
ts = tables.tree_sequence()
14751475
return ts
14761476

@@ -1830,9 +1830,10 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
18301830
tables.nodes.add_row(
18311831
flags=constants.NODE_IS_SAMPLE_ANCESTOR,
18321832
time=times[j],
1833-
metadata=_encode_raw_metadata(
1834-
{"sample_data_id": int(sample_indexes[s])}
1835-
),
1833+
metadata={
1834+
"ancestor_data_id": tskit.NULL,
1835+
"sample_data_id": int(sample_indexes[s]),
1836+
},
18361837
)
18371838
s += 1
18381839
else:

0 commit comments

Comments
 (0)