1313import dask
1414import dask .array
1515
16+ from .external import SparseXArrayDataArray , SparseXArrayDataSet
17+
1618try :
1719 import anndata
1820except ImportError :
@@ -31,17 +33,18 @@ def fetch_X(idx):
3133 if idx .size == 1 :
3234 retval = np .squeeze (retval , axis = 0 )
3335
34- return retval .astype (np .float32 )
36+ return retval .astype (np .float64 )
3537
3638 delayed_fetch = dask .delayed (fetch_X , pure = True )
3739 X = [
3840 dask .array .from_delayed (
3941 delayed_fetch (idx ),
4042 shape = (num_features ,),
41- dtype = np .float32
43+ dtype = np .float64
4244 ) for idx in range (num_observations )
4345 ]
44- X = xr .DataArray (dask .array .stack (X ), dims = dims )
46+
47+ X = data
4548
4649 # currently broken:
4750 # X = data.X
@@ -53,9 +56,9 @@ def fetch_X(idx):
5356
5457
5558def xarray_from_data (
56- data : Union [anndata .AnnData , xr .DataArray , xr .Dataset , np .ndarray ],
59+ data : Union [anndata .AnnData , anndata . base . Raw , xr .DataArray , xr .Dataset , np .ndarray , scipy . sparse . csr_matrix ],
5760 dims : Union [Tuple , List ] = ("observations" , "features" )
58- ) -> xr . DataArray :
61+ ):
5962 """
6063 Parse any array-like object, xr.DataArray, xr.Dataset or anndata.Anndata and return a xarray containing
6164 the observations.
@@ -64,26 +67,52 @@ def xarray_from_data(
6467 :param dims: tuple or list with two strings. Specifies the names of the xarray dimensions.
6568 :return: xr.DataArray of shape `dims`
6669 """
67- if anndata is not None and isinstance (data , anndata .AnnData ):
70+ if anndata is not None and (isinstance (data , anndata .AnnData ) or isinstance (data , anndata .base .Raw )):
71+ # Anndata.raw does not have obs_names.
72+ if isinstance (data , anndata .AnnData ):
73+ obs_names = np .asarray (data .obs_names )
74+ else :
75+ obs_names = ["obs_" + str (i ) for i in range (data .X .shape [0 ])]
76+
6877 if scipy .sparse .issparse (data .X ):
69- X = _sparse_to_xarray (data .X , dims = dims )
70- X .coords [dims [0 ]] = np .asarray (data .obs_names )
71- X .coords [dims [1 ]] = np .asarray (data .var_names )
78+ # X = _sparse_to_xarray(data.X, dims=dims)
79+ # X.coords[dims[0]] = np.asarray(data.obs_names)
80+ # X.coords[dims[1]] = np.asarray(data.var_names)
81+ X = SparseXArrayDataSet (
82+ X = data .X ,
83+ obs_names = np .asarray (obs_names ),
84+ feature_names = np .asarray (data .var_names ),
85+ dims = dims
86+ )
7287 else :
73- X = data .X
74- X = xr .DataArray (X , dims = dims , coords = {
75- dims [0 ]: np .asarray (data .obs_names ),
76- dims [1 ]: np .asarray (data .var_names ),
77- })
88+ X = xr .DataArray (
89+ data .X ,
90+ dims = dims ,
91+ coords = {
92+ dims [0 ]: np .asarray (obs_names ),
93+ dims [1 ]: np .asarray (data .var_names ),
94+ }
95+ )
7896 elif isinstance (data , xr .Dataset ):
7997 X : xr .DataArray = data ["X" ]
8098 elif isinstance (data , xr .DataArray ):
8199 X = data
100+ elif isinstance (data , SparseXArrayDataSet ):
101+ X = data
102+ elif scipy .sparse .issparse (data ):
103+ # X = _sparse_to_xarray(data, dims=dims)
104+ # X.coords[dims[0]] = np.asarray(data.obs_names)
105+ # X.coords[dims[1]] = np.asarray(data.var_names)
106+ X = SparseXArrayDataSet (
107+ X = data ,
108+ obs_names = None ,
109+ feature_names = None ,
110+ dims = dims
111+ )
112+ elif isinstance (data , np .ndarray ):
113+ X = xr .DataArray (data , dims = dims )
82114 else :
83- if scipy .sparse .issparse (data ):
84- X = _sparse_to_xarray (data , dims = dims )
85- else :
86- X = xr .DataArray (data , dims = dims )
115+ raise ValueError ("batchglm data parsing: data format %s not recognized" % type (data ))
87116
88117 return X
89118
@@ -537,7 +566,7 @@ def parse_constraints(
537566 Parse constraint matrix into xarray.
538567
539568 :param dmat: Design matrix.
540- :param a constraint matrix
569+ :param constraints: a constraint matrix
541570 :return: constraint matrix in xarray format
542571 """
543572 constraints_ar = xr .DataArray (
0 commit comments