Skip to content

Commit ab3d5a8

Browse files
committed
fix shared memory size issues
1 parent ff3f823 commit ab3d5a8

File tree

6 files changed

+151
-117
lines changed

6 files changed

+151
-117
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ __pycache__/
4242
.vscode/
4343
.cursor/
4444
.claude/
45-
Claude.md
45+
CLAUDE.md
4646

4747
# tmp_scripts
4848
tmp_scripts/

src/rapids_singlecell/pertpy_gpu/_distance.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ class Distance:
2626
2727
Parameters
2828
----------
29-
metric : str
29+
metric
3030
Distance metric. Currently supported: 'edistance' (energy distance).
31-
layer_key : str | None
31+
layer_key
3232
Key in adata.layers for cell data. Mutually exclusive with 'obsm_key'.
33-
obsm_key : str | None
33+
obsm_key
3434
Key in adata.obsm for embeddings. Mutually exclusive with 'layer_key'.
3535
Defaults to 'X_pca' if neither is specified.
3636
@@ -51,12 +51,12 @@ def __init__(
5151
obsm_key: str | None = None,
5252
):
5353
"""Initialize Distance calculator with specified metric."""
54-
if layer_key and obsm_key:
54+
if layer_key is not None and obsm_key is not None:
5555
raise ValueError(
5656
"Cannot use 'layer_key' and 'obsm_key' at the same time.\n"
5757
"Please provide only one of the two keys."
5858
)
59-
if not layer_key and not obsm_key:
59+
if layer_key is None and obsm_key is None:
6060
obsm_key = "X_pca"
6161

6262
self.metric = metric
@@ -93,14 +93,14 @@ def __call__(
9393
9494
Parameters
9595
----------
96-
X : np.ndarray | cp.ndarray
96+
X
9797
First array of shape (n_samples_x, n_features)
98-
Y : np.ndarray | cp.ndarray
98+
Y
9999
Second array of shape (n_samples_y, n_features)
100100
101101
Returns
102102
-------
103-
distance : float
103+
float
104104
Distance between X and Y
105105
106106
Examples
@@ -131,17 +131,17 @@ def pairwise(
131131
132132
Parameters
133133
----------
134-
adata : AnnData
134+
adata
135135
Annotated data matrix
136-
groupby : str
136+
groupby
137137
Key in adata.obs for grouping cells
138-
groups : Sequence[str] | None
138+
groups
139139
Specific groups to compute (if None, use all)
140-
bootstrap : bool
140+
bootstrap
141141
Whether to compute bootstrap variance estimates
142-
n_bootstrap : int
142+
n_bootstrap
143143
Number of bootstrap iterations (if bootstrap=True)
144-
random_state : int
144+
random_state
145145
Random seed for reproducibility
146146
147147
Returns
@@ -180,24 +180,24 @@ def onesided_distances(
180180
181181
Parameters
182182
----------
183-
adata : AnnData
183+
adata
184184
Annotated data matrix
185-
groupby : str
185+
groupby
186186
Key in adata.obs for grouping cells
187-
selected_group : str
187+
selected_group
188188
Reference group to compute distances from
189-
groups : Sequence[str] | None
189+
groups
190190
Specific groups to compute distances to (if None, use all)
191-
bootstrap : bool
191+
bootstrap
192192
Whether to compute bootstrap variance estimates
193-
n_bootstrap : int
193+
n_bootstrap
194194
Number of bootstrap iterations (if bootstrap=True)
195-
random_state : int
195+
random_state
196196
Random seed for reproducibility
197197
198198
Returns
199199
-------
200-
distances : pd.Series | tuple[pd.Series, pd.Series]
200+
distances
201201
Series containing distances from selected_group to all other groups.
202202
If bootstrap=True, returns tuple of (distances, distances_var).
203203
@@ -238,18 +238,18 @@ def bootstrap(
238238
239239
Parameters
240240
----------
241-
X : np.ndarray | cp.ndarray
241+
X
242242
First array of shape (n_samples_x, n_features)
243-
Y : np.ndarray | cp.ndarray
243+
Y
244244
Second array of shape (n_samples_y, n_features)
245-
n_bootstrap : int
245+
n_bootstrap
246246
Number of bootstrap iterations
247-
random_state : int
247+
random_state
248248
Random seed for reproducibility
249249
250250
Returns
251251
-------
252-
result : MeanVar
252+
result
253253
Named tuple containing mean and variance of bootstrapped distances
254254
255255
Examples
@@ -287,22 +287,22 @@ def bootstrap_adata(
287287
288288
Parameters
289289
----------
290-
adata : AnnData
290+
adata
291291
Annotated data matrix
292-
groupby : str
292+
groupby
293293
Key in adata.obs for grouping cells
294-
group_a : str
294+
group_a
295295
First group name
296-
group_b : str
296+
group_b
297297
Second group name
298-
n_bootstrap : int
298+
n_bootstrap
299299
Number of bootstrap iterations
300-
random_state : int
300+
random_state
301301
Random seed for reproducibility
302302
303303
Returns
304304
-------
305-
result : MeanVar
305+
result
306306
Named tuple containing mean and variance of bootstrapped distances
307307
308308
Examples

src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class BaseMetric(ABC):
1818
1919
Parameters
2020
----------
21-
obsm_key : str
21+
obsm_key
2222
Key in adata.obsm for embeddings (default: 'X_pca')
2323
"""
2424

@@ -42,17 +42,17 @@ def pairwise(
4242
4343
Parameters
4444
----------
45-
adata : AnnData
45+
adata
4646
Annotated data matrix
47-
groupby : str
47+
groupby
4848
Key in adata.obs for grouping cells
49-
groups : Sequence[str] | None
49+
groups
5050
Specific groups to compute (if None, use all)
51-
bootstrap : bool
51+
bootstrap
5252
Whether to compute bootstrap variance estimates
53-
n_bootstrap : int
53+
n_bootstrap
5454
Number of bootstrap iterations (if bootstrap=True)
55-
random_state : int
55+
random_state
5656
Random seed for reproducibility
5757
5858
Returns
@@ -78,19 +78,19 @@ def onesided_distances(
7878
7979
Parameters
8080
----------
81-
adata : AnnData
81+
adata
8282
Annotated data matrix
83-
groupby : str
83+
groupby
8484
Key in adata.obs for grouping cells
85-
selected_group : str
85+
selected_group
8686
Reference group to compute distances from
87-
groups : Sequence[str] | None
87+
groups
8888
Specific groups to compute distances to (if None, use all)
89-
bootstrap : bool
89+
bootstrap
9090
Whether to compute bootstrap variance estimates
91-
n_bootstrap : int
91+
n_bootstrap
9292
Number of bootstrap iterations (if bootstrap=True)
93-
random_state : int
93+
random_state
9494
Random seed for reproducibility
9595
9696
Returns
@@ -117,24 +117,24 @@ def bootstrap(
117117
118118
Parameters
119119
----------
120-
adata : AnnData
120+
adata
121121
Annotated data matrix
122-
groupby : str
122+
groupby
123123
Key in adata.obs for grouping cells
124-
group_a : str
124+
group_a
125125
First group name
126-
group_b : str
126+
group_b
127127
Second group name
128-
n_bootstrap : int
128+
n_bootstrap
129129
Number of bootstrap iterations
130-
random_state : int
130+
random_state
131131
Random seed for reproducibility
132132
133133
Returns
134134
-------
135-
mean : float
135+
mean
136136
Bootstrap mean distance
137-
variance : float
137+
variance
138138
Bootstrap variance
139139
"""
140140
raise NotImplementedError(

0 commit comments

Comments
 (0)