@@ -1920,32 +1920,33 @@ def _pull_attr(
19201920 raise ValueError ("All mods should be present in mdata.mod" )
19211921 elif len (mods ) == self .n_mod :
19221922 mods = None
1923- for k , v in {"common" : common , "nonunique" : nonunique , "unique" : unique }.items ():
1924- assert v is None , f"Cannot use mods with { k } ."
19251923
19261924 if only_drop :
19271925 drop = True
19281926
19291927 cols = _classify_attr_columns (
1930- np .concatenate (
1931- [
1932- [f"{ m } :{ val } " for val in getattr (mod , attr ).columns .values ]
1933- for m , mod in self .mod .items ()
1934- ]
1935- ),
1936- self .mod .keys (),
1928+ {modname : getattr (mod , attr ).columns for modname , mod in self .mod .items ()}
19371929 )
19381930
19391931 if columns is not None :
19401932 for k , v in {"common" : common , "nonunique" : nonunique , "unique" : unique }.items ():
1941- assert v is None , f"Cannot use { k } with columns."
1933+ if v is not None :
1934+ warnings .warn (
1935+ f"Both columns and { k } given. Columns take precedence, { k } will be ignored" ,
1936+ RuntimeWarning ,
1937+ stacklevel = 2 ,
1938+ )
19421939
19431940 # - modname1:column -> [modname1:column]
19441941 # - column -> [modname1:column, modname2:column, ...]
1945- cols = [col for col in cols if col ["name" ] in columns or col ["derived_name" ] in columns ]
1946-
1947- if mods is not None :
1948- cols = [col for col in cols if col ["prefix" ] in mods ]
1942+ cols = {
1943+ prefix : [
1944+ col
1945+ for col in modcols
1946+ if col ["name" ] in columns or col ["derived_name" ] in columns
1947+ ]
1948+ for prefix , modcols in cols .items ()
1949+ }
19491950
19501951 # TODO: Counter for columns in order to track their usage
19511952 # and error out if some columns were not used
@@ -1959,10 +1960,17 @@ def _pull_attr(
19591960 unique = True
19601961
19611962 selector = {"common" : common , "nonunique" : nonunique , "unique" : unique }
1963+ cols = {
1964+ prefix : [col for col in modcols if selector [col ["class" ]]]
1965+ for prefix , modcols in cols .items ()
1966+ }
19621967
1963- cols = [col for col in cols if selector [col ["class" ]]]
1968+ if mods is not None :
1969+ cols = {prefix : cols [prefix ] for prefix in mods }
19641970
1965- derived_name_count = Counter ([col ["derived_name" ] for col in cols ])
1971+ derived_name_count = Counter (
1972+ [col ["derived_name" ] for modcols in cols .values () for col in modcols ]
1973+ )
19661974
19671975 # - axis == self.axis
19681976 # e.g. combine var from multiple modalities (with unique vars)
@@ -1995,44 +2003,36 @@ def _pull_attr(
19952003 n_attr = self .n_vars if attr == "var" else self .n_obs
19962004
19972005 dfs : list [pd .DataFrame ] = []
1998- for m , mod in self .mod .items ():
1999- if mods is not None and m not in mods :
2000- continue
2006+ for m , modcols in cols .items ():
2007+ mod = self .mod [m ]
20012008 mod_map = attrmap [m ].ravel ()
2002- mod_n_attr = mod .n_vars if attr == "var" else mod .n_obs
2003- mask = mod_map != 0
2004-
2005- mod_df = getattr (mod , attr )
2006- mod_columns = [
2007- col ["derived_name" ] for col in cols if col ["prefix" ] == "" or col ["prefix" ] == m
2008- ]
2009- mod_df = mod_df [mod_df .columns .intersection (mod_columns )]
2009+ mask = mod_map > 0
20102010
2011+ mod_df = getattr (mod , attr )[[col ["derived_name" ] for col in modcols ]]
20112012 if drop :
20122013 getattr (mod , attr ).drop (columns = mod_df .columns , inplace = True )
20132014
2014- # Don't use modname: prefix if columns need to be joined
2015- if join_common or join_nonunique or (not prefix_unique ):
2016- cols_special = [
2017- col ["derived_name" ]
2018- for col in cols
2019- if (
2020- (col ["class" ] == "common" ) & join_common
2021- or (col ["class" ] == "nonunique" ) & join_nonunique
2022- or (col ["class" ] == "unique" ) & (not prefix_unique )
2015+ mod_df .rename (
2016+ columns = {
2017+ col ["derived_name" ]: col ["name" ]
2018+ for col in modcols
2019+ if not (
2020+ (
2021+ join_common
2022+ and col ["class" ] == "common"
2023+ or join_nonunique
2024+ and col ["class" ] == "nonunique"
2025+ or not prefix_unique
2026+ and col ["class" ] == "unique"
2027+ )
2028+ and derived_name_count [col ["derived_name" ]] == col ["count" ]
20232029 )
2024- and col ["prefix" ] == m
2025- and derived_name_count [col ["derived_name" ]] == col ["count" ]
2026- ]
2027- mod_df .columns = [
2028- col if col in cols_special else f"{ m } :{ col } " for col in mod_df .columns
2029- ]
2030- else :
2031- mod_df .columns = [f"{ m } :{ col } " for col in mod_df .columns ]
2030+ },
2031+ inplace = True ,
2032+ )
20322033
20332034 mod_df = (
20342035 _maybe_coerce_to_boolean (mod_df )
2035- .set_index (np .arange (mod_n_attr ))
20362036 .iloc [mod_map [mask ] - 1 ]
20372037 .set_index (np .arange (n_attr )[mask ])
20382038 .reindex (np .arange (n_attr ))
0 commit comments