2828from .repr import MUDATA_CSS , block_matrix , details_block_table
2929from .utils import (
3030 MetadataColumn ,
31- _classify_attr_columns ,
3231 _make_index_unique ,
3332 _maybe_coerce_to_bool ,
3433 _maybe_coerce_to_boolean ,
@@ -1923,9 +1922,29 @@ def _pull_attr(
19231922 if only_drop :
19241923 drop = True
19251924
1926- cols = _classify_attr_columns (
1927- {modname : getattr (mod , attr ).columns for modname , mod in self .mod .items ()}
1928- )
1925+ cols : dict [str , list [MetadataColumn ]] = {}
1926+
1927+ # get all columns from all modalities and count how many times each column is present
1928+ derived_name_counts = Counter ()
1929+ for prefix , mod in self .mod .items ():
1930+ modcols = getattr (mod , attr ).columns
1931+ ccols = []
1932+ for name in modcols :
1933+ ccols .append (
1934+ MetadataColumn (
1935+ allowed_prefixes = self .mod .keys (),
1936+ prefix = prefix ,
1937+ name = name ,
1938+ strip_prefix = False ,
1939+ )
1940+ )
1941+ derived_name_counts [name ] += 1
1942+ cols [prefix ] = ccols
1943+
1944+ for prefix , modcols in cols .items ():
1945+ for col in modcols :
1946+ count = derived_name_counts [col .derived_name ]
1947+ col .count = count # this is important to classify columns
19291948
19301949 if columns is not None :
19311950 for k , v in {"common" : common , "nonunique" : nonunique , "unique" : unique }.items ():
@@ -1936,8 +1955,7 @@ def _pull_attr(
19361955 stacklevel = 2 ,
19371956 )
19381957
1939- # - modname1:column -> [modname1:column]
1940- # - column -> [modname1:column, modname2:column, ...]
1958+ # keep only requested columns
19411959 cols = {
19421960 prefix : [
19431961 col for col in modcols if col .name in columns or col .derived_name in columns
@@ -1956,15 +1974,18 @@ def _pull_attr(
19561974 if unique is None :
19571975 unique = True
19581976
1977+ # filter columns by class, keep only those that were requested
19591978 selector = {"common" : common , "nonunique" : nonunique , "unique" : unique }
19601979 cols = {
19611980 prefix : [col for col in modcols if selector [col .klass ]]
19621981 for prefix , modcols in cols .items ()
19631982 }
19641983
1984+ # filter columns, keep only requested modalities
19651985 if mods is not None :
19661986 cols = {prefix : cols [prefix ] for prefix in mods }
19671987
1988+ # count final filtered column names, required later to decide whether to prefix a column with its source modality
19681989 derived_name_count = Counter (
19691990 [col .derived_name for modcols in cols .values () for col in modcols ]
19701991 )
@@ -2008,6 +2029,8 @@ def _pull_attr(
20082029 if drop :
20092030 getattr (mod , attr ).drop (columns = mod_df .columns , inplace = True )
20102031
2032+ # prepend modality prefix to column names if requested via arguments and there are no skipped modalities with
2033+ # the same column name (prefixing those columns may cause problems with future pulls or pushes)
20112034 mod_df .rename (
20122035 columns = {
20132036 col .derived_name : col .name
@@ -2027,6 +2050,7 @@ def _pull_attr(
20272050 inplace = True ,
20282051 )
20292052
2053+ # reorder modality DF to conform to global order
20302054 mod_df = (
20312055 _maybe_coerce_to_boolean (mod_df )
20322056 .iloc [mod_map [mask ] - 1 ]
@@ -2242,6 +2266,7 @@ def _push_attr(
22422266 if only_drop :
22432267 drop = True
22442268
2269+ # get all global columns
22452270 cols = [
22462271 MetadataColumn (allowed_prefixes = self .mod .keys (), name = name )
22472272 for name in getattr (self , attr ).columns
@@ -2256,9 +2281,7 @@ def _push_attr(
22562281 stacklevel = 2 ,
22572282 )
22582283
2259- # - modname1:column -> [modname1:column]
2260- # - column -> [modname1:column, modname2:column, ...]
2261- # preemptively drop columns from other modalities
2284+ # keep only requested columns
22622285 cols = [
22632286 col
22642287 for col in cols
@@ -2271,8 +2294,8 @@ def _push_attr(
22712294 if prefixed is None :
22722295 prefixed = True
22732296
2297+ # filter columns by class, keep only those that were requested
22742298 selector = {"common" : common , "unknown" : prefixed }
2275-
22762299 cols = [col for col in cols if selector [col .klass ]]
22772300
22782301 if len (cols ) == 0 :
@@ -2290,20 +2313,22 @@ def _push_attr(
22902313 )
22912314
22922315 attrmap = getattr (self , f"{ attr } map" )
2293- _n_attr = self .n_vars if attr == "var" else self .n_obs
2294-
22952316 for m , mod in self .mod .items ():
22962317 if mods is not None and m not in mods :
22972318 continue
22982319
22992320 mod_map = attrmap [m ].ravel ()
2300- mask = mod_map != 0
2301- mod_n_attr = mod .n_vars if attr == "var " else mod .n_obs
2321+ mask = mod_map > 0
2322+ mod_n_attr = mod .n_obs if attr == "obs " else mod .n_vars
23022323
2324+ # get all common and modality-specific columns for the current modality
23032325 mod_cols = [col for col in cols if col .prefix == m or col .klass == "common" ]
23042326 df = getattr (self , attr )[mask ].loc [:, [col .name for col in mod_cols ]]
2327+
2328+ # strip modality prefix where necessary
23052329 df .columns = [col .derived_name for col in mod_cols ]
23062330
2331+ # reorder global DF to conform to modality order
23072332 df = df .iloc [np .argsort (mod_map [mask ])].set_index (np .arange (mod_n_attr ))
23082333
23092334 if not only_drop :
0 commit comments