Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 76 additions & 76 deletions src/mudata/_core/mudata.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from .file_backing import MuDataFileManager
from .repr import MUDATA_CSS, block_matrix, details_block_table
from .utils import (
MetadataColumn,
_classify_attr_columns,
_classify_prefixed_columns,
_make_index_unique,
_maybe_coerce_to_bool,
_maybe_coerce_to_boolean,
Expand Down Expand Up @@ -1915,37 +1915,35 @@ def _pull_attr(
if mods is not None:
if isinstance(mods, str):
mods = [mods]
mods = list(dict.fromkeys(mods))
if not all(m in self.mod for m in mods):
raise ValueError("All mods should be present in mdata.mod")
elif len(mods) == self.n_mod:
mods = None
for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items():
assert v is None, f"Cannot use mods with {k}."

if only_drop:
drop = True

cols = _classify_attr_columns(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think cols should be a class, not just a dictionary, with methods to encapsulate the below iterations (and it's sub-dictionary value as well to handle the modcols logic)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally I would agree with you, but this is used at exactly one place in the entire codebase, so I think a class is a bit overkill at this point.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I would consider making it a named tuple for performance reasons.

np.concatenate(
[
[f"{m}:{val}" for val in getattr(mod, attr).columns.values]
for m, mod in self.mod.items()
]
),
self.mod.keys(),
{modname: getattr(mod, attr).columns for modname, mod in self.mod.items()}
)

if columns is not None:
for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items():
assert v is None, f"Cannot use {k} with columns."
if v is not None:
warnings.warn(
f"Both columns and {k} given. Columns take precedence, {k} will be ignored",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would something like this improve readability? (I am not sure we have a consistent policy for formatting in such cases.)

Both `columns=...` and `{k}=True` were given. <...>

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If yes, this is also true for similar warnings in other parts of the PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that would be a bit misleading here, since the warning will also be emitted if k=False or any other value which is not the None default. Perhaps something like

Both `columns=...` and `{k}={locals()[k]}` were given...

? But I'm not sure if that brings the message across that it should just be not passed at all (as in leave the None default).

RuntimeWarning,
stacklevel=2,
)

# - modname1:column -> [modname1:column]
# - column -> [modname1:column, modname2:column, ...]
cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns]

if mods is not None:
cols = [col for col in cols if col["prefix"] in mods]
cols = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i.e., with cols as a class this could be

prefix_to_cols = cols.filter_by_name_or_derived_name(columns)

(I would also advocate changing the name from cols to prefix_to_cols to avoid confusion with columns)

prefix: [
col for col in modcols if col.name in columns or col.derived_name in columns
]
for prefix, modcols in cols.items()
}

# TODO: Counter for columns in order to track their usage
# and error out if some columns were not used
Expand All @@ -1959,27 +1957,33 @@ def _pull_attr(
unique = True

selector = {"common": common, "nonunique": nonunique, "unique": unique}
cols = {
prefix: [col for col in modcols if selector[col.klass]]
for prefix, modcols in cols.items()
}

cols = [col for col in cols if selector[col["class"]]]
if mods is not None:
cols = {prefix: cols[prefix] for prefix in mods}

derived_name_count = Counter([col["derived_name"] for col in cols])
derived_name_count = Counter(
[col.derived_name for modcols in cols.values() for col in modcols]
)

# - axis == self.axis
# e.g. combine var from multiple modalities (with unique vars)
# e.g. combine obs from multiple modalities (with shared obs)
# - 1 - axis == self.axis
# . e.g. combine obs from multiple modalities (with shared obs)
axis = 0 if attr == "var" else 1
# e.g. combine var from multiple modalities (with unique vars)
axis = 0 if attr == "obs" else 1

if 1 - axis == self.axis or self.axis == -1:
if axis == self.axis or self.axis == -1:
if join_common or join_nonunique:
raise ValueError(f"Cannot join columns with the same name for shared {attr}_names.")

if join_common is None:
join_common = False
if attr == "var":
join_common = self.axis == 0
elif attr == "obs":
if attr == "obs":
join_common = self.axis == 1
else:
join_common = self.axis == 0

if join_nonunique is None:
join_nonunique = False
Expand All @@ -1995,44 +1999,36 @@ def _pull_attr(
n_attr = self.n_vars if attr == "var" else self.n_obs

dfs: list[pd.DataFrame] = []
for m, mod in self.mod.items():
if mods is not None and m not in mods:
continue
for m, modcols in cols.items():
mod = self.mod[m]
mod_map = attrmap[m].ravel()
mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
mask = mod_map != 0

mod_df = getattr(mod, attr)
mod_columns = [
col["derived_name"] for col in cols if col["prefix"] == "" or col["prefix"] == m
]
mod_df = mod_df[mod_df.columns.intersection(mod_columns)]
mask = mod_map > 0

mod_df = getattr(mod, attr)[[col.derived_name for col in modcols]]
if drop:
getattr(mod, attr).drop(columns=mod_df.columns, inplace=True)

# Don't use modname: prefix if columns need to be joined
if join_common or join_nonunique or (not prefix_unique):
cols_special = [
col["derived_name"]
for col in cols
if (
(col["class"] == "common") & join_common
or (col["class"] == "nonunique") & join_nonunique
or (col["class"] == "unique") & (not prefix_unique)
mod_df.rename(
columns={
col.derived_name: col.name
for col in modcols
if not (
(
join_common
and col.klass == "common"
or join_nonunique
and col.klass == "nonunique"
or not prefix_unique
and col.klass == "unique"
)
and derived_name_count[col.derived_name] == col.count
)
and col["prefix"] == m
and derived_name_count[col["derived_name"]] == col["count"]
]
mod_df.columns = [
col if col in cols_special else f"{m}:{col}" for col in mod_df.columns
]
else:
mod_df.columns = [f"{m}:{col}" for col in mod_df.columns]
},
inplace=True,
)

mod_df = (
_maybe_coerce_to_boolean(mod_df)
.set_index(np.arange(mod_n_attr))
.iloc[mod_map[mask] - 1]
.set_index(np.arange(n_attr)[mask])
.reindex(np.arange(n_attr))
Expand Down Expand Up @@ -2242,39 +2238,47 @@ def _push_attr(
raise ValueError("All mods should be present in mdata.mod")
elif len(mods) == self.n_mod:
mods = None
for k, v in {"common": common, "prefixed": prefixed}.items():
assert v is None, f"Cannot use mods with {k}."

if only_drop:
drop = True

cols = _classify_prefixed_columns(getattr(self, attr).columns.values, self.mod.keys())
cols = [
MetadataColumn(allowed_prefixes=self.mod.keys(), name=name)
for name in getattr(self, attr).columns
]

if columns is not None:
for k, v in {"common": common, "prefixed": prefixed}.items():
assert v is None, f"Cannot use columns with {k}."
if v:
warnings.warn(
f"Both columns and {k} given. Columns take precedence, {k} will be ignored",
RuntimeWarning,
stacklevel=2,
)

# - modname1:column -> [modname1:column]
# - column -> [modname1:column, modname2:column, ...]
cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns]

# preemptively drop columns from other modalities
if mods is not None:
cols = [col for col in cols if col["prefix"] in mods or col["prefix"] == ""]
cols = [
col
for col in cols
if (col.name in columns or col.derived_name in columns)
and (col.prefix is None or mods is not None and col.prefix in mods)
]
else:
if common is None:
common = True
if prefixed is None:
prefixed = True

selector = {"common": common, "prefixed": prefixed}
selector = {"common": common, "unknown": prefixed}

cols = [col for col in cols if selector[col["class"]]]
cols = [col for col in cols if selector[col.klass]]

if len(cols) == 0:
return

derived_name_count = Counter([col["derived_name"] for col in cols])
derived_name_count = Counter([col.derived_name for col in cols])
for c, count in derived_name_count.items():
# if count > 1, there are both colname and modname:colname present
if count > 1 and c in getattr(self, attr).columns:
Expand All @@ -2292,19 +2296,15 @@ def _push_attr(
if mods is not None and m not in mods:
continue

mod_map = attrmap[m]
mod_map = attrmap[m].ravel()
mask = mod_map != 0
mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs

mod_cols = [col for col in cols if col["prefix"] == m or col["class"] == "common"]
df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]]
df.columns = [col["derived_name"] for col in mod_cols]
mod_cols = [col for col in cols if col.prefix == m or col.klass == "common"]
df = getattr(self, attr)[mask].loc[:, [col.name for col in mod_cols]]
df.columns = [col.derived_name for col in mod_cols]

df = (
df.set_index(np.arange(mod_n_attr))
.iloc[mod_map[mask] - 1]
.set_index(np.arange(mod_n_attr))
)
df = df.iloc[np.argsort(mod_map[mask])].set_index(np.arange(mod_n_attr))

if not only_drop:
# TODO: _maybe_coerce_to_bool
Expand All @@ -2317,7 +2317,7 @@ def _push_attr(

if drop:
for col in cols:
getattr(self, attr).drop(col["name"], axis=1, inplace=True)
getattr(self, attr).drop(col.name, axis=1, inplace=True)

def push_obs(
self,
Expand Down
Loading