-
Notifications
You must be signed in to change notification settings - Fork 21
fixes for push_attr/pull_attr #105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
3a4b2e1
8647ddd
2b26ebf
21b2e83
8b8d1c1
f4d5eca
0d646ed
717c362
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,8 +27,8 @@ | |
| from .file_backing import MuDataFileManager | ||
| from .repr import MUDATA_CSS, block_matrix, details_block_table | ||
| from .utils import ( | ||
| MetadataColumn, | ||
| _classify_attr_columns, | ||
| _classify_prefixed_columns, | ||
| _make_index_unique, | ||
| _maybe_coerce_to_bool, | ||
| _maybe_coerce_to_boolean, | ||
|
|
@@ -1915,37 +1915,35 @@ def _pull_attr( | |
| if mods is not None: | ||
| if isinstance(mods, str): | ||
| mods = [mods] | ||
| mods = list(dict.fromkeys(mods)) | ||
| if not all(m in self.mod for m in mods): | ||
| raise ValueError("All mods should be present in mdata.mod") | ||
| elif len(mods) == self.n_mod: | ||
| mods = None | ||
| for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items(): | ||
| assert v is None, f"Cannot use mods with {k}." | ||
|
|
||
| if only_drop: | ||
| drop = True | ||
|
|
||
| cols = _classify_attr_columns( | ||
| np.concatenate( | ||
| [ | ||
| [f"{m}:{val}" for val in getattr(mod, attr).columns.values] | ||
| for m, mod in self.mod.items() | ||
| ] | ||
| ), | ||
| self.mod.keys(), | ||
| {modname: getattr(mod, attr).columns for modname, mod in self.mod.items()} | ||
| ) | ||
|
|
||
| if columns is not None: | ||
| for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items(): | ||
| assert v is None, f"Cannot use {k} with columns." | ||
| if v is not None: | ||
| warnings.warn( | ||
| f"Both columns and {k} given. Columns take precedence, {k} will be ignored", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would something like this improve readability? (I am not sure we have a consistent policy for formatting in such cases.)
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If yes, this is also true for similar warnings in other parts of the PR.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that would be a bit misleading here, since the warning will also be emitted if ? But I'm not sure if that brings the message across that it should just be not passed at all (as in leave the None default). |
||
| RuntimeWarning, | ||
| stacklevel=2, | ||
| ) | ||
|
|
||
| # - modname1:column -> [modname1:column] | ||
| # - column -> [modname1:column, modname2:column, ...] | ||
| cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns] | ||
|
|
||
| if mods is not None: | ||
| cols = [col for col in cols if col["prefix"] in mods] | ||
| cols = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i.e., with prefix_to_cols = cols.filter_by_name_or_derived_name(colums)(I would also advocate changing the name from |
||
| prefix: [ | ||
| col for col in modcols if col.name in columns or col.derived_name in columns | ||
| ] | ||
| for prefix, modcols in cols.items() | ||
| } | ||
|
|
||
| # TODO: Counter for columns in order to track their usage | ||
| # and error out if some columns were not used | ||
|
|
@@ -1959,27 +1957,33 @@ def _pull_attr( | |
| unique = True | ||
|
|
||
| selector = {"common": common, "nonunique": nonunique, "unique": unique} | ||
| cols = { | ||
| prefix: [col for col in modcols if selector[col.klass]] | ||
| for prefix, modcols in cols.items() | ||
| } | ||
|
|
||
| cols = [col for col in cols if selector[col["class"]]] | ||
| if mods is not None: | ||
| cols = {prefix: cols[prefix] for prefix in mods} | ||
|
|
||
| derived_name_count = Counter([col["derived_name"] for col in cols]) | ||
| derived_name_count = Counter( | ||
| [col.derived_name for modcols in cols.values() for col in modcols] | ||
| ) | ||
|
|
||
| # - axis == self.axis | ||
| # e.g. combine var from multiple modalities (with unique vars) | ||
| # e.g. combine obs from multiple modalities (with shared obs) | ||
| # - 1 - axis == self.axis | ||
| # . e.g. combine obs from multiple modalities (with shared obs) | ||
| axis = 0 if attr == "var" else 1 | ||
| # e.g. combine var from multiple modalities (with unique vars) | ||
| axis = 0 if attr == "obs" else 1 | ||
|
|
||
| if 1 - axis == self.axis or self.axis == -1: | ||
| if axis == self.axis or self.axis == -1: | ||
| if join_common or join_nonunique: | ||
| raise ValueError(f"Cannot join columns with the same name for shared {attr}_names.") | ||
|
|
||
| if join_common is None: | ||
| join_common = False | ||
| if attr == "var": | ||
| join_common = self.axis == 0 | ||
| elif attr == "obs": | ||
| if attr == "obs": | ||
| join_common = self.axis == 1 | ||
| else: | ||
| join_common = self.axis == 0 | ||
|
|
||
| if join_nonunique is None: | ||
| join_nonunique = False | ||
|
|
@@ -1995,44 +1999,36 @@ def _pull_attr( | |
| n_attr = self.n_vars if attr == "var" else self.n_obs | ||
|
|
||
| dfs: list[pd.DataFrame] = [] | ||
| for m, mod in self.mod.items(): | ||
| if mods is not None and m not in mods: | ||
| continue | ||
| for m, modcols in cols.items(): | ||
| mod = self.mod[m] | ||
| mod_map = attrmap[m].ravel() | ||
| mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs | ||
| mask = mod_map != 0 | ||
|
|
||
| mod_df = getattr(mod, attr) | ||
| mod_columns = [ | ||
| col["derived_name"] for col in cols if col["prefix"] == "" or col["prefix"] == m | ||
| ] | ||
| mod_df = mod_df[mod_df.columns.intersection(mod_columns)] | ||
| mask = mod_map > 0 | ||
|
|
||
| mod_df = getattr(mod, attr)[[col.derived_name for col in modcols]] | ||
| if drop: | ||
| getattr(mod, attr).drop(columns=mod_df.columns, inplace=True) | ||
|
|
||
| # Don't use modname: prefix if columns need to be joined | ||
| if join_common or join_nonunique or (not prefix_unique): | ||
| cols_special = [ | ||
| col["derived_name"] | ||
| for col in cols | ||
| if ( | ||
| (col["class"] == "common") & join_common | ||
| or (col["class"] == "nonunique") & join_nonunique | ||
| or (col["class"] == "unique") & (not prefix_unique) | ||
| mod_df.rename( | ||
| columns={ | ||
| col.derived_name: col.name | ||
| for col in modcols | ||
| if not ( | ||
| ( | ||
| join_common | ||
| and col.klass == "common" | ||
| or join_nonunique | ||
| and col.klass == "nonunique" | ||
| or not prefix_unique | ||
| and col.klass == "unique" | ||
| ) | ||
| and derived_name_count[col.derived_name] == col.count | ||
ilia-kats marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) | ||
| and col["prefix"] == m | ||
| and derived_name_count[col["derived_name"]] == col["count"] | ||
| ] | ||
| mod_df.columns = [ | ||
| col if col in cols_special else f"{m}:{col}" for col in mod_df.columns | ||
| ] | ||
| else: | ||
| mod_df.columns = [f"{m}:{col}" for col in mod_df.columns] | ||
| }, | ||
| inplace=True, | ||
| ) | ||
|
|
||
| mod_df = ( | ||
| _maybe_coerce_to_boolean(mod_df) | ||
| .set_index(np.arange(mod_n_attr)) | ||
| .iloc[mod_map[mask] - 1] | ||
| .set_index(np.arange(n_attr)[mask]) | ||
| .reindex(np.arange(n_attr)) | ||
|
|
@@ -2242,39 +2238,47 @@ def _push_attr( | |
| raise ValueError("All mods should be present in mdata.mod") | ||
| elif len(mods) == self.n_mod: | ||
| mods = None | ||
| for k, v in {"common": common, "prefixed": prefixed}.items(): | ||
| assert v is None, f"Cannot use mods with {k}." | ||
|
|
||
| if only_drop: | ||
| drop = True | ||
|
|
||
| cols = _classify_prefixed_columns(getattr(self, attr).columns.values, self.mod.keys()) | ||
| cols = [ | ||
| MetadataColumn(allowed_prefixes=self.mod.keys(), name=name) | ||
| for name in getattr(self, attr).columns | ||
| ] | ||
|
|
||
| if columns is not None: | ||
| for k, v in {"common": common, "prefixed": prefixed}.items(): | ||
| assert v is None, f"Cannot use columns with {k}." | ||
| if v: | ||
| warnings.warn( | ||
| f"Both columns and {k} given. Columns take precedence, {k} will be ignored", | ||
| RuntimeWarning, | ||
| stacklevel=2, | ||
| ) | ||
|
|
||
| # - modname1:column -> [modname1:column] | ||
| # - column -> [modname1:column, modname2:column, ...] | ||
| cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns] | ||
|
|
||
| # preemptively drop columns from other modalities | ||
| if mods is not None: | ||
| cols = [col for col in cols if col["prefix"] in mods or col["prefix"] == ""] | ||
| cols = [ | ||
| col | ||
| for col in cols | ||
| if (col.name in columns or col.derived_name in columns) | ||
| and (col.prefix is None or mods is not None and col.prefix in mods) | ||
| ] | ||
| else: | ||
| if common is None: | ||
| common = True | ||
| if prefixed is None: | ||
| prefixed = True | ||
|
|
||
| selector = {"common": common, "prefixed": prefixed} | ||
| selector = {"common": common, "unknown": prefixed} | ||
|
|
||
| cols = [col for col in cols if selector[col["class"]]] | ||
| cols = [col for col in cols if selector[col.klass]] | ||
|
|
||
| if len(cols) == 0: | ||
| return | ||
|
|
||
| derived_name_count = Counter([col["derived_name"] for col in cols]) | ||
| derived_name_count = Counter([col.derived_name for col in cols]) | ||
| for c, count in derived_name_count.items(): | ||
| # if count > 1, there are both colname and modname:colname present | ||
| if count > 1 and c in getattr(self, attr).columns: | ||
|
|
@@ -2292,19 +2296,15 @@ def _push_attr( | |
| if mods is not None and m not in mods: | ||
| continue | ||
|
|
||
| mod_map = attrmap[m] | ||
| mod_map = attrmap[m].ravel() | ||
| mask = mod_map != 0 | ||
| mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs | ||
|
|
||
| mod_cols = [col for col in cols if col["prefix"] == m or col["class"] == "common"] | ||
| df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]] | ||
| df.columns = [col["derived_name"] for col in mod_cols] | ||
| mod_cols = [col for col in cols if col.prefix == m or col.klass == "common"] | ||
| df = getattr(self, attr)[mask].loc[:, [col.name for col in mod_cols]] | ||
ilia-kats marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| df.columns = [col.derived_name for col in mod_cols] | ||
|
|
||
| df = ( | ||
| df.set_index(np.arange(mod_n_attr)) | ||
| .iloc[mod_map[mask] - 1] | ||
| .set_index(np.arange(mod_n_attr)) | ||
| ) | ||
| df = df.iloc[np.argsort(mod_map[mask])].set_index(np.arange(mod_n_attr)) | ||
|
|
||
| if not only_drop: | ||
| # TODO: _maybe_coerce_to_bool | ||
|
|
@@ -2317,7 +2317,7 @@ def _push_attr( | |
|
|
||
| if drop: | ||
| for col in cols: | ||
| getattr(self, attr).drop(col["name"], axis=1, inplace=True) | ||
| getattr(self, attr).drop(col.name, axis=1, inplace=True) | ||
|
|
||
| def push_obs( | ||
| self, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
colsshould be a class, not just a dictionary, with methods to encapsulate the below iterations (and it's sub-dictionary value as well to handle themodcolslogic)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally I would agree with you, but this used at exactly one place in the entire codebase, so I think a class is a bit overkill at this point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although I would consider making it a named tuple for performance reasons.