Skip to content

Commit 070a729

Browse files
Merge pull request #507 from jeromekelleher/update-node-date
Update node date
2 parents 7758245 + 63abe75 commit 070a729

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

sc2ts/stats.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,30 @@ def node_data(ts, inheritance_stats=True):
4444

4545
md = ts.nodes_metadata
4646
cols = {k: md[k].astype(str) for k in md.dtype.names}
47+
dtype = {k: pd.StringDtype() for k in md.dtype.names}
4748
flags = ts.nodes_flags
4849
cols["node_id"] = np.arange(ts.num_nodes)
50+
dtype["node_id"] = "int"
4951
cols["is_sample"] = (flags & tskit.NODE_IS_SAMPLE) > 0
52+
dtype["is_sample"] = "bool"
5053
cols["is_recombinant"] = (flags & core.NODE_IS_RECOMBINANT) > 0
54+
dtype["is_recombinant"] = "bool"
5155
# Are other flags useful of just debug info? Lets leave them out
5256
# for now.
5357
cols["num_mutations"] = np.bincount(ts.mutations_node, minlength=ts.num_nodes)
58+
dtype["num_mutations"] = "int"
5459
# This is the same as is_recombinant but less obvious
5560
# cols["num_parents"] = np.bincount(ts.edges_child,
5661
# minlength=ts.num_edges)
5762

5863
if inheritance_stats:
5964
counter = jit.count(ts)
6065
cols["max_descendant_samples"] = counter.nodes_max_descendant_samples
61-
cols["date"] = convert_date(ts, ts.nodes_time)
62-
return pd.DataFrame(cols)
66+
dtype["max_descendant_samples"] = "int"
67+
if "time_zero_date" in ts.metadata:
68+
cols["date"] = convert_date(ts, ts.nodes_time)
69+
# Let Pandas infer the dtype of this to get the appropriate date type
70+
return pd.DataFrame(cols).astype(dtype)
6371

6472

6573
def mutation_data(ts, inheritance_stats=True):
@@ -77,10 +85,15 @@ def mutation_data(ts, inheritance_stats=True):
7785
cols["node"] = ts.mutations_node
7886
cols["inherited_state"] = inherited_state
7987
cols["derived_state"] = derived_state
80-
cols["date"] = convert_date(ts, ts.mutations_time)
88+
if "time_zero_date" in ts.metadata:
89+
cols["date"] = convert_date(ts, ts.mutations_time)
8190
if inheritance_stats:
8291
counter = jit.count(ts)
8392
cols["num_descendants"] = counter.mutations_num_descendants
8493
cols["num_inheritors"] = counter.mutations_num_inheritors
8594

86-
return pd.DataFrame(cols)
95+
dtype = {k: "int" for k in cols if k != "date"}
96+
dtype["inherited_state"] = pd.StringDtype()
97+
dtype["derived_state"] = pd.StringDtype()
98+
99+
return pd.DataFrame(cols).astype(dtype)

tests/test_info.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ def fx_ti_2020_02_15(fx_ts_map):
2929
@pytest.fixture
3030
def fx_ts_min_2020_02_15(fx_ts_map):
3131
ts = fx_ts_map["2020-02-15"]
32-
return sc2ts.minimise_metadata(ts)
32+
field_mapping = {
33+
"strain": "sample_id",
34+
"Viridian_pangolin": "pango",
35+
"Viridian_scorpio": "scorpio",
36+
}
37+
return sc2ts.minimise_metadata(ts, field_mapping)
3338

3439

3540
@pytest.fixture
@@ -55,6 +60,7 @@ def test_copying_table(self, fx_ti_recombinant_example_1):
5560
assert "TestChild" in ct_via_ti
5661
assert ct_via_ti == ct_via_ts
5762

63+
5864
class TestTreeInfo:
5965
def test_tree_info_values(self, fx_ti_2020_02_13):
6066
ti = fx_ti_2020_02_13
@@ -229,18 +235,30 @@ class TestDataFuncs:
229235

230236
def test_example_node(self, fx_ts_min_2020_02_15, fx_ti_2020_02_15):
231237
ts = fx_ts_min_2020_02_15
232-
df = sc2ts.node_data(fx_ts_min_2020_02_15)
233238
ti = fx_ti_2020_02_15
239+
df = sc2ts.node_data(ts)
234240
assert df.shape[0] == ti.ts.num_nodes
235241
nt.assert_array_equal(ti.nodes_num_mutations, df["num_mutations"])
236242
nt.assert_array_equal(np.arange(ti.ts.num_nodes), df["node_id"])
237243
nt.assert_array_equal(
238244
ti.nodes_max_descendant_samples, df["max_descendant_samples"]
239245
)
246+
print(ti.nodes_date.dtype)
247+
print(df["date"].dtype)
240248
nt.assert_array_equal(ti.nodes_date, df["date"])
241249
assert list(np.where(df["is_recombinant"])[0]) == list(ti.recombinants)
242250
assert list(np.where(df["is_sample"])[0]) == list(ts.samples())
243251

252+
def test_example_node_no_date(self, fx_ts_min_2020_02_15, fx_ti_2020_02_15):
253+
ts = fx_ts_min_2020_02_15
254+
df1 = sc2ts.node_data(ts)
255+
tables = ts.dump_tables()
256+
tables.metadata = {}
257+
df2 = sc2ts.node_data(tables.tree_sequence())
258+
assert set(df1) == set(df2) | {"date"}
259+
for col in df2:
260+
nt.assert_array_equal(df1[col].values, df2[col].values)
261+
244262
def test_example_mutation(self, fx_ts_min_2020_02_15, fx_ti_2020_02_15):
245263
ts = fx_ts_min_2020_02_15
246264
df = sc2ts.mutation_data(fx_ts_min_2020_02_15)
@@ -253,3 +271,13 @@ def test_example_mutation(self, fx_ts_min_2020_02_15, fx_ti_2020_02_15):
253271
nt.assert_array_equal(ti.mutations_inherited_state, df["inherited_state"])
254272
nt.assert_array_equal(ts.mutations_node, df["node"])
255273
nt.assert_array_equal(ts.mutations_parent, df["parent"])
274+
275+
def test_example_mutation_no_date(self, fx_ts_min_2020_02_15, fx_ti_2020_02_15):
276+
ts = fx_ts_min_2020_02_15
277+
df1 = sc2ts.mutation_data(ts)
278+
tables = ts.dump_tables()
279+
tables.metadata = {}
280+
df2 = sc2ts.mutation_data(tables.tree_sequence())
281+
assert set(df1) == set(df2) | {"date"}
282+
for col in df2:
283+
nt.assert_array_equal(df1[col].values, df2[col].values)

0 commit comments

Comments
 (0)