Skip to content

Commit 2b26ebf

Browse files
committed
push/pull: replace dict holding column information with custom class
1 parent 8647ddd commit 2b26ebf

File tree

2 files changed

+80
-81
lines changed

2 files changed

+80
-81
lines changed

src/mudata/_core/mudata.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from .file_backing import MuDataFileManager
2828
from .repr import MUDATA_CSS, block_matrix, details_block_table
2929
from .utils import (
30+
MetadataColumn,
3031
_classify_attr_columns,
31-
_classify_prefixed_columns,
3232
_make_index_unique,
3333
_maybe_coerce_to_bool,
3434
_maybe_coerce_to_boolean,
@@ -1940,9 +1940,7 @@ def _pull_attr(
19401940
# - column -> [modname1:column, modname2:column, ...]
19411941
cols = {
19421942
prefix: [
1943-
col
1944-
for col in modcols
1945-
if col["name"] in columns or col["derived_name"] in columns
1943+
col for col in modcols if col.name in columns or col.derived_name in columns
19461944
]
19471945
for prefix, modcols in cols.items()
19481946
}
@@ -1960,15 +1958,15 @@ def _pull_attr(
19601958

19611959
selector = {"common": common, "nonunique": nonunique, "unique": unique}
19621960
cols = {
1963-
prefix: [col for col in modcols if selector[col["class"]]]
1961+
prefix: [col for col in modcols if selector[col.klass]]
19641962
for prefix, modcols in cols.items()
19651963
}
19661964

19671965
if mods is not None:
19681966
cols = {prefix: cols[prefix] for prefix in mods}
19691967

19701968
derived_name_count = Counter(
1971-
[col["derived_name"] for modcols in cols.values() for col in modcols]
1969+
[col.derived_name for modcols in cols.values() for col in modcols]
19721970
)
19731971

19741972
# - axis == self.axis
@@ -2006,24 +2004,24 @@ def _pull_attr(
20062004
mod_map = attrmap[m].ravel()
20072005
mask = mod_map > 0
20082006

2009-
mod_df = getattr(mod, attr)[[col["derived_name"] for col in modcols]]
2007+
mod_df = getattr(mod, attr)[[col.derived_name for col in modcols]]
20102008
if drop:
20112009
getattr(mod, attr).drop(columns=mod_df.columns, inplace=True)
20122010

20132011
mod_df.rename(
20142012
columns={
2015-
col["derived_name"]: col["name"]
2013+
col.derived_name: col.name
20162014
for col in modcols
20172015
if not (
20182016
(
20192017
join_common
2020-
and col["class"] == "common"
2018+
and col.klass == "common"
20212019
or join_nonunique
2022-
and col["class"] == "nonunique"
2020+
and col.klass == "nonunique"
20232021
or not prefix_unique
2024-
and col["class"] == "unique"
2022+
and col.klass == "unique"
20252023
)
2026-
and derived_name_count[col["derived_name"]] == col["count"]
2024+
and derived_name_count[col.derived_name] == col.count
20272025
)
20282026
},
20292027
inplace=True,
@@ -2244,7 +2242,10 @@ def _push_attr(
22442242
if only_drop:
22452243
drop = True
22462244

2247-
cols = _classify_prefixed_columns(getattr(self, attr).columns.values, self.mod.keys())
2245+
cols = [
2246+
MetadataColumn(allowed_prefixes=self.mod.keys(), name=name)
2247+
for name in getattr(self, attr).columns
2248+
]
22482249

22492250
if columns is not None:
22502251
for k, v in {"common": common, "prefixed": prefixed}.items():
@@ -2261,23 +2262,23 @@ def _push_attr(
22612262
cols = [
22622263
col
22632264
for col in cols
2264-
if (col["name"] in columns or col["derived_name"] in columns)
2265-
and (col["prefix"] == "" or mods is not None and col["prefix"] in mods)
2265+
if (col.name in columns or col.derived_name in columns)
2266+
and (col.prefix is None or mods is not None and col.prefix in mods)
22662267
]
22672268
else:
22682269
if common is None:
22692270
common = True
22702271
if prefixed is None:
22712272
prefixed = True
22722273

2273-
selector = {"common": common, "prefixed": prefixed}
2274+
selector = {"common": common, "unknown": prefixed}
22742275

2275-
cols = [col for col in cols if selector[col["class"]]]
2276+
cols = [col for col in cols if selector[col.klass]]
22762277

22772278
if len(cols) == 0:
22782279
return
22792280

2280-
derived_name_count = Counter([col["derived_name"] for col in cols])
2281+
derived_name_count = Counter([col.derived_name for col in cols])
22812282
for c, count in derived_name_count.items():
22822283
# if count > 1, there are both colname and modname:colname present
22832284
if count > 1 and c in getattr(self, attr).columns:
@@ -2299,9 +2300,9 @@ def _push_attr(
22992300
mask = mod_map != 0
23002301
mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
23012302

2302-
mod_cols = [col for col in cols if col["prefix"] == m or col["class"] == "common"]
2303-
df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]]
2304-
df.columns = [col["derived_name"] for col in mod_cols]
2303+
mod_cols = [col for col in cols if col.prefix == m or col.klass == "common"]
2304+
df = getattr(self, attr)[mask].loc[:, [col.name for col in mod_cols]]
2305+
df.columns = [col.derived_name for col in mod_cols]
23052306

23062307
df = df.iloc[np.argsort(mod_map[mask])].set_index(np.arange(mod_n_attr))
23072308

@@ -2316,7 +2317,7 @@ def _push_attr(
23162317

23172318
if drop:
23182319
for col in cols:
2319-
getattr(self, attr).drop(col["name"], axis=1, inplace=True)
2320+
getattr(self, attr).drop(col.name, axis=1, inplace=True)
23202321

23212322
def push_obs(
23222323
self,

src/mudata/_core/utils.py

Lines changed: 57 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import Counter
22
from collections.abc import Mapping, Sequence
3-
from typing import TypeVar
3+
from typing import Literal, TypeVar
44

55
import numpy as np
66
import pandas as pd
@@ -38,7 +38,56 @@ def _maybe_coerce_to_boolean(df: T) -> T:
3838
return df
3939

4040

41-
def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list[dict[str, str]]]:
41+
class MetadataColumn:
42+
__slots__ = ("prefix", "derived_name", "count", "_allowed_prefixes")
43+
44+
def __init__(
45+
self,
46+
*,
47+
allowed_prefixes: Sequence[str],
48+
prefix: str | None = None,
49+
name: str | None = None,
50+
count: int = 0,
51+
):
52+
self._allowed_prefixes = allowed_prefixes
53+
if prefix is None:
54+
self.name = name
55+
else:
56+
self.prefix = prefix
57+
self.derived_name = name
58+
self.count = count
59+
60+
@property
61+
def name(self) -> str:
62+
if self.prefix is not None:
63+
return f"{self.prefix}:{self.derived_name}"
64+
else:
65+
return self.derived_name
66+
67+
@name.setter
68+
def name(self, new_name):
69+
if (
70+
len(name_split := new_name.split(":", 1)) < 2
71+
or name_split[0] not in self._allowed_prefixes
72+
):
73+
self.prefix = None
74+
self.derived_name = new_name
75+
else:
76+
self.prefix, self.derived_name = name_split
77+
78+
@property
79+
def klass(self) -> Literal["common", "unique", "nonunique", "unknown"]:
80+
if self.prefix is None or self.count == len(self._allowed_prefixes):
81+
return "common"
82+
elif self.count == 1:
83+
return "unique"
84+
elif self.count > 0:
85+
return "nonunique"
86+
else:
87+
return "unknown"
88+
89+
90+
def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list[MetadataColumn]]:
4291
"""
4392
Classify names into common, non-unique, and unique
4493
w.r.t. to the list of prefixes.
@@ -50,72 +99,21 @@ def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list
5099
- Unique columns are prefixed by modality names,
51100
and there is only one modality prefix
52101
for a column with a certain name.
53-
54-
E.g. {"mod1": ["annotation", "unique"], "mod2": ["annotation"]} will be classified
55-
into {"mod1": [{"name": "mod1:annotation", "derived_name": "annotation", "count": 2, "class": "nonunique"},
56-
{"name": "mod1:unique", "derived_name": "unique", "count": 1, "class": "unique"}}],
57-
"mod2": [{"name": "mod2:annotation", "derived_name": "annotation", "count": 2, "class": "nonunique"}],
58-
}
59102
"""
60-
n_mod = len(names)
61-
res: dict[str, list[dict[str, str]]] = {}
103+
res: dict[str, list[MetadataColumn]] = {}
62104

63105
derived_name_counts = Counter()
64-
for prefix, names in names.items():
106+
for prefix, pnames in names.items():
65107
cres = []
66-
for name in names:
67-
cres.append(
68-
{
69-
"name": f"{prefix}:{name}",
70-
"derived_name": name,
71-
}
72-
)
108+
for name in pnames:
109+
cres.append(MetadataColumn(allowed_prefixes=names.keys(), prefix=prefix, name=name))
73110
derived_name_counts[name] += 1
74111
res[prefix] = cres
75112

76113
for prefix, names in res.items():
77114
for name_res in names:
78-
count = derived_name_counts[name_res["derived_name"]]
79-
name_res["count"] = count
80-
name_res["class"] = (
81-
"common" if count == n_mod else "unique" if count == 1 else "nonunique"
82-
)
83-
84-
return res
85-
86-
87-
def _classify_prefixed_columns(
88-
names: Sequence[str], prefixes: Sequence[str]
89-
) -> Sequence[dict[str, str]]:
90-
"""
91-
Classify names into common and prefixed
92-
w.r.t. to the list of prefixes.
93-
94-
- Common columns do not have modality prefixes.
95-
- Prefixed columns are prefixed by modality names.
96-
97-
E.g. ["global", "mod1:annotation", "mod2:annotation", "mod1:unique"] will be classified
98-
into [
99-
{"name": "global", "prefix": "", "derived_name": "global", "class": "common"},
100-
{"name": "mod1:annotation", "prefix": "mod1", "derived_name": "annotation", "class": "prefixed"},
101-
{"name": "mod2:annotation", "prefix": "mod2", "derived_name": "annotation", "class": "prefixed"},
102-
{"name": "mod1:unique", "prefix": "mod1", "derived_name": "annotation", "class": "prefixed"},
103-
]
104-
"""
105-
res: list[dict[str, str]] = []
106-
107-
for name in names:
108-
if len(name_split := name.split(":", 1)) < 2 or name_split[0] not in prefixes:
109-
res.append({"name": name, "prefix": "", "derived_name": name, "class": "common"})
110-
else:
111-
res.append(
112-
{
113-
"name": name,
114-
"prefix": name_split[0],
115-
"derived_name": name_split[1],
116-
"class": "prefixed",
117-
}
118-
)
115+
count = derived_name_counts[name_res.derived_name]
116+
name_res.count = count
119117

120118
return res
121119

0 commit comments

Comments
 (0)