24 changes: 14 additions & 10 deletions tests/test_inference.py
@@ -1094,9 +1094,14 @@ def test_from_standard_tree_sequence(self):
     assert i1.flags == i2.flags
     assert tsutil.json_metadata_is_subset(i1.metadata, i2.metadata)
 # Unless inference is perfect, internal nodes may differ, but sample nodes
-# should be identical
+# should be identical. Node metadata is not transferred, however, and a tsinfer-
+# specific node metadata schema is used (where empty is None rather than b"")
+assert (
+    ts_inferred.table_metadata_schemas.node
+    == tsinfer.formats.node_metadata_schema()
+)
 for n1, n2 in zip(ts.samples(), ts_inferred.samples()):
-    assert ts.node(n1) == ts_inferred.node(n2)
+    assert ts.node(n1).replace(metadata=None) == ts_inferred.node(n2)
 # Sites can have metadata added by the inference process, but inferred site
 # metadata should always include all the metadata in the original ts
 for s1, s2 in zip(ts.sites(), ts_inferred.sites()):
@@ -1586,12 +1591,13 @@ def verify(self, sample_data, mismatch_ratio=None, recombination_rate=None):
 ancestors_time = ancestor_data.ancestors_time[:]
 num_ancestor_nodes = 0
 for n in ancestors_ts.nodes():
-    md = json.loads(n.metadata) if n.metadata else {}
+    md = n.metadata if n.metadata else {}
     if tsinfer.is_pc_ancestor(n.flags):
-        assert not ("ancestor_data_id" in md)
+        if "tsinfer" in md:
+            assert "ancestor_data_id" not in md["tsinfer"]
     else:
-        assert "ancestor_data_id" in md
-        assert ancestors_time[md["ancestor_data_id"]] == n.time
+        assert "tsinfer" in md and "ancestor_data_id" in md["tsinfer"]
+        assert ancestors_time[md["tsinfer"]["ancestor_data_id"]] == n.time
         num_ancestor_nodes += 1
 assert num_ancestor_nodes == ancestor_data.num_ancestors

@@ -3114,8 +3120,7 @@ def verify_augmented_ancestors(
     node = t2.nodes[m + j]
     assert node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR
     assert node.time == 1
-    metadata = json.loads(node.metadata.decode())
-    assert node_id == metadata["sample_data_id"]
+    assert node_id == node.metadata["tsinfer"]["sample_data_id"]

 t2.nodes.truncate(len(t1.nodes))
 # Adding and subtracting 1 can lead to small diffs, so we compare
@@ -3265,8 +3270,7 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
 num_sample_ancestors = 0
 for node in final_ts.nodes():
     if node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR:
-        metadata = json.loads(node.metadata.decode())
-        assert metadata["sample_data_id"] in subset
+        assert node.metadata["tsinfer"]["sample_data_id"] in subset
         num_sample_ancestors += 1
 assert expected_sample_ancestors == num_sample_ancestors
 tsinfer.verify(samples, final_ts.simplify())
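A rough sketch of the behaviour these updated tests assert (the msprime simulation call here is illustrative, not part of the PR): after inference, the node table carries the fixed tsinfer metadata schema, and sample-node metadata is not copied over from the source tree sequence, decoding as None instead.

```python
import msprime
import tsinfer

# Illustrative input: any tree sequence with variable sites would do.
ts = msprime.simulate(10, mutation_rate=5, random_seed=42)
ts_inferred = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(ts))

# The inferred node table uses the fixed tsinfer schema...
assert (
    ts_inferred.table_metadata_schemas.node
    == tsinfer.formats.node_metadata_schema()
)
# ...and sample-node metadata is not transferred: it decodes as None.
for u in ts_inferred.samples():
    assert ts_inferred.node(u).metadata is None
```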
36 changes: 36 additions & 0 deletions tsinfer/formats.py
@@ -74,6 +74,42 @@ def permissive_json_schema():
 }


+def node_metadata_schema():
+    # This is fixed by tsinfer: users cannot add to the node metadata
+    return tskit.MetadataSchema(
+        {
+            "codec": "struct",
+            "type": ["object", "null"],
+            "properties": {
+                "tsinfer": {
Member
I don't see the point in the "tsinfer" top key here.

Member Author
I thought it was cleaner, because then when we e.g. add values inferred from tsdate we can store them in a tsdate property. That way it is clear where the different metadata values have come from, and they don't really risk overwriting each other etc.

Member Author
Addendum: I see this as a nice advantage of struct: we can be verbose in the explanation (and nesting) of the properties in the metadata, and it doesn't take up extra space in the encoding of the metadata for each row. That's my understanding, anyway.

Member
The nice thing is that there is no storage cost to the additional level.

Member Author
> The nice thing is that there is no storage cost to the additional level.

Yes, that's what I was trying to say (but you said it better)

"description": "Information about node identity "
"from the tsinfer inference process",
"type": "object",
"properties": {
"ancestor_data_id": {
"description": "The corresponding ancestor ID "
"in the ancestors file created by the inference process, "
"or -1 if not applicable",
"type": "number",
"binaryFormat": "i",
"default": -1,
},
"sample_data_id": {
"description": "The corresponding sample ID "
"in the sample data file used for inference, "
"or -1 if not applicable",
"type": "number",
"binaryFormat": "i",
"default": -1,
},
},
},
},
"additionalProperties": False,
}
)


def np_obj_equal(np_obj_array1, np_obj_array2):
"""
A replacement for np.array_equal to test equality of numpy arrays that
Expand Down
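To make the thread above concrete, here is a small sketch (not part of the PR) of what the struct codec does with this schema: the nested "tsinfer" level is purely descriptive and adds nothing to the encoded bytes, which hold only the two packed "i" (int32) fields, and empty metadata decodes to None rather than b"".

```python
import tsinfer

schema = tsinfer.formats.node_metadata_schema()

# Omitted fields fall back to their schema default of -1.
encoded = schema.validate_and_encode_row({"tsinfer": {"ancestor_data_id": 42}})
print(len(encoded))  # 8: two packed int32 fields, no extra bytes for the nesting
print(schema.decode_row(encoded))
# -> {'tsinfer': {'ancestor_data_id': 42, 'sample_data_id': -1}}

# Empty metadata decodes to None under this schema.
print(schema.decode_row(b""))  # None
```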
18 changes: 8 additions & 10 deletions tsinfer/inference.py
@@ -1411,12 +1411,7 @@ def get_ancestors_tree_sequence(self):
 pc_ancestors = is_pc_ancestor(flags)
 tables.nodes.set_columns(flags=flags, time=times)

-# # FIXME we should do this as a struct codec?
-# dict_schema = permissive_json_schema()
-# dict_schema = add_to_schema(dict_schema, "ancestor_data_id",
-#     {"type": "integer"})
-# schema = tskit.MetadataSchema(dict_schema)
-# tables.nodes.schema = schema
+tables.nodes.metadata_schema = formats.node_metadata_schema()

 # Add metadata for any non-PC node, pointing to the original ancestor
 metadata = []
@@ -1425,7 +1420,11 @@
     if is_pc:
         metadata.append(b"")
     else:
-        metadata.append(_encode_raw_metadata({"ancestor_data_id": ancestor}))
+        metadata.append(
+            tables.nodes.metadata_schema.validate_and_encode_row(
+                {"tsinfer": {"ancestor_data_id": ancestor}}
+            )
+        )
         ancestor += 1
 tables.nodes.packset_metadata(metadata)
 left, right, parent, child = tsb.dump_edges()
@@ -1471,6 +1470,7 @@ def store_output(self):
     tables = tskit.TableCollection(
         sequence_length=self.ancestor_data.sequence_length
     )
+    tables.nodes.metadata_schema = formats.node_metadata_schema()
     ts = tables.tree_sequence()
 return ts

@@ -1830,9 +1830,7 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
     tables.nodes.add_row(
         flags=constants.NODE_IS_SAMPLE_ANCESTOR,
         time=times[j],
-        metadata=_encode_raw_metadata(
-            {"sample_data_id": int(sample_indexes[s])}
-        ),
+        metadata={"tsinfer": {"sample_data_id": int(sample_indexes[s])}},
     )
     s += 1
 else:
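For reference, a minimal sketch (not from the PR) of the two ways the diff above writes node metadata once the schema is attached to the table: add_row() accepts a plain dict and encodes it via the table's schema, whereas packset_metadata() takes pre-encoded bytes for the whole column, which is why get_ancestors_tree_sequence() calls validate_and_encode_row() itself and leaves b"" for path-compressed ancestors.

```python
import tskit
import tsinfer

tables = tskit.TableCollection(sequence_length=1.0)
tables.nodes.metadata_schema = tsinfer.formats.node_metadata_schema()

# Pattern used in get_augmented_ancestors_tree_sequence(): hand add_row() a
# dict and let the table's schema validate and encode it.
tables.nodes.add_row(
    flags=tsinfer.NODE_IS_SAMPLE_ANCESTOR,
    time=1,
    metadata={"tsinfer": {"sample_data_id": 0}},
)

# Pattern used in get_ancestors_tree_sequence(): pre-encode each row and set
# the whole metadata column in one go. packset_metadata() replaces the
# metadata of every existing row; path-compressed ancestors keep b"".
tables.nodes.add_row(flags=0, time=2, metadata={"tsinfer": {"ancestor_data_id": 1}})
encoded_rows = [
    tables.nodes.metadata_schema.validate_and_encode_row(
        {"tsinfer": {"ancestor_data_id": 0}}
    ),
    b"",  # a path-compression ancestor carries no metadata
]
tables.nodes.packset_metadata(encoded_rows)
```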