diff --git a/tests/test_inference.py b/tests/test_inference.py index f986fe6f..6a401974 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1094,9 +1094,14 @@ def test_from_standard_tree_sequence(self): assert i1.flags == i2.flags assert tsutil.json_metadata_is_subset(i1.metadata, i2.metadata) # Unless inference is perfect, internal nodes may differ, but sample nodes - # should be identical + # should be identical. Node metadata is not transferred, however, and a tsinfer- + # specific node metadata schema is used (where empty is None rather than b"") + assert ( + ts_inferred.table_metadata_schemas.node + == tsinfer.formats.node_metadata_schema() + ) for n1, n2 in zip(ts.samples(), ts_inferred.samples()): - assert ts.node(n1) == ts_inferred.node(n2) + assert ts.node(n1).replace(metadata=None) == ts_inferred.node(n2) # Sites can have metadata added by the inference process, but inferred site # metadata should always include all the metadata in the original ts for s1, s2 in zip(ts.sites(), ts_inferred.sites()): @@ -1586,12 +1591,13 @@ def verify(self, sample_data, mismatch_ratio=None, recombination_rate=None): ancestors_time = ancestor_data.ancestors_time[:] num_ancestor_nodes = 0 for n in ancestors_ts.nodes(): - md = json.loads(n.metadata) if n.metadata else {} + md = n.metadata if n.metadata else {} if tsinfer.is_pc_ancestor(n.flags): - assert not ("ancestor_data_id" in md) + if "tsinfer" in md: + assert "ancestor_data_id" not in md["tsinfer"] else: - assert "ancestor_data_id" in md - assert ancestors_time[md["ancestor_data_id"]] == n.time + assert "tsinfer" in md and "ancestor_data_id" in md["tsinfer"] + assert ancestors_time[md["tsinfer"]["ancestor_data_id"]] == n.time num_ancestor_nodes += 1 assert num_ancestor_nodes == ancestor_data.num_ancestors @@ -3114,8 +3120,7 @@ def verify_augmented_ancestors( node = t2.nodes[m + j] assert node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR assert node.time == 1 - metadata = json.loads(node.metadata.decode()) - assert node_id == metadata["sample_data_id"] + assert node_id == node.metadata["tsinfer"]["sample_data_id"] t2.nodes.truncate(len(t1.nodes)) # Adding and subtracting 1 can lead to small diffs, so we compare @@ -3265,8 +3270,7 @@ def verify_example(self, full_subset, samples, ancestors, path_compression): num_sample_ancestors = 0 for node in final_ts.nodes(): if node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR: - metadata = json.loads(node.metadata.decode()) - assert metadata["sample_data_id"] in subset + assert node.metadata["tsinfer"]["sample_data_id"] in subset num_sample_ancestors += 1 assert expected_sample_ancestors == num_sample_ancestors tsinfer.verify(samples, final_ts.simplify()) diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 3df416ed..678d3c80 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -74,6 +74,42 @@ def permissive_json_schema(): } +def node_metadata_schema(): + # This is fixed by tsinfer: users cannot add to the node metadata + return tskit.MetadataSchema( + { + "codec": "struct", + "type": ["object", "null"], + "properties": { + "tsinfer": { + "description": "Information about node identity " + "from the tsinfer inference process", + "type": "object", + "properties": { + "ancestor_data_id": { + "description": "The corresponding ancestor ID " + "in the ancestors file created by the inference process, " + "or -1 if not applicable", + "type": "number", + "binaryFormat": "i", + "default": -1, + }, + "sample_data_id": { + "description": "The corresponding sample ID " + "in the sample data file used for inference, " + "or -1 if not applicable", + "type": "number", + "binaryFormat": "i", + "default": -1, + }, + }, + }, + }, + "additionalProperties": False, + } + ) + + def np_obj_equal(np_obj_array1, np_obj_array2): """ A replacement for np.array_equal to test equality of numpy arrays that diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 7da92646..314eaac7 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -1411,12 +1411,7 @@ def get_ancestors_tree_sequence(self): pc_ancestors = is_pc_ancestor(flags) tables.nodes.set_columns(flags=flags, time=times) - # # FIXME we should do this as a struct codec? - # dict_schema = permissive_json_schema() - # dict_schema = add_to_schema(dict_schema, "ancestor_data_id", - # {"type": "integer"}) - # schema = tskit.MetadataSchema(dict_schema) - # tables.nodes.schema = schema + tables.nodes.metadata_schema = formats.node_metadata_schema() # Add metadata for any non-PC node, pointing to the original ancestor metadata = [] @@ -1425,7 +1420,11 @@ def get_ancestors_tree_sequence(self): if is_pc: metadata.append(b"") else: - metadata.append(_encode_raw_metadata({"ancestor_data_id": ancestor})) + metadata.append( + tables.nodes.metadata_schema.validate_and_encode_row( + {"tsinfer": {"ancestor_data_id": ancestor}} + ) + ) ancestor += 1 tables.nodes.packset_metadata(metadata) left, right, parent, child = tsb.dump_edges() @@ -1471,6 +1470,7 @@ def store_output(self): tables = tskit.TableCollection( sequence_length=self.ancestor_data.sequence_length ) + tables.nodes.metadata_schema = formats.node_metadata_schema() ts = tables.tree_sequence() return ts @@ -1830,9 +1830,7 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes): tables.nodes.add_row( flags=constants.NODE_IS_SAMPLE_ANCESTOR, time=times[j], - metadata=_encode_raw_metadata( - {"sample_data_id": int(sample_indexes[s])} - ), + metadata={"tsinfer": {"sample_data_id": int(sample_indexes[s])}}, ) s += 1 else: