Skip to content

Commit 06afeee

Browse files
authored
add option to choose axis in labeler (#48)
* add option to choose axis in labeler
* use cupy argmax
* lint
* fix test case to use series
1 parent 9561e53 commit 06afeee

File tree

2 files changed

+45
-11
lines changed

2 files changed

+45
-11
lines changed

crossfit/op/label.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Union
22

33
import cudf
4+
import cupy as cp
45

56
from crossfit.op.base import Op
67

@@ -12,29 +13,27 @@ def __init__(
1213
cols=None,
1314
keep_cols=None,
1415
pre=None,
15-
keep_prob: bool = False,
1616
suffix: str = "labels",
17+
axis=-1,
1718
):
1819
super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
1920
self.labels = labels
20-
self.keep_prob = keep_prob
2121
self.suffix = suffix
22+
self.axis = axis
2223

2324
def call_column(self, data: cudf.Series) -> cudf.Series:
    """Label a list-valued column by the argmax of its scores.

    The flattened leaf values of ``data`` are reshaped to
    ``(data.size, *shape_of_first_element)``, the argmax is taken along
    ``self.axis``, and each resulting index is mapped to the entry of
    ``self.labels`` at that position.

    Raises:
        ValueError: if ``data`` is a ``cudf.DataFrame`` instead of a Series.
        RuntimeError: if the argmax along ``self.axis`` has more than one
            dimension.
    """
    if isinstance(data, cudf.DataFrame):
        raise ValueError(
            "data must be a Series, got DataFrame. Add a pre step to convert to Series"
        )

    # Row count first, then whatever dimensions the first element carries.
    full_shape = (data.size,) + cp.asarray(data.iloc[0]).shape
    flat_values = data.list.leaves.values
    score_array = flat_values.reshape(full_shape)
    winners = score_array.argmax(self.axis)

    # Only a flat vector of class indices can become a Series of labels.
    if winners.ndim > 1:
        raise RuntimeError(
            f"Max category of the axis {self.axis} of data is not a 1-d array."
        )

    index_to_label = dict(enumerate(self.labels))
    return cudf.Series(winners).map(index_to_label)
@@ -60,7 +59,7 @@ def call(self, data: Union[cudf.Series, cudf.DataFrame]) -> Union[cudf.Series, c
6059
def meta(self):
6160
labeled = {"labels": "string"}
6261

63-
if len(self.cols) > 1:
62+
if self.cols and len(self.cols) > 1:
6463
labeled = {
6564
self._construct_name(col, suffix): dtype
6665
for col in self.cols

tests/op/test_label.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import pytest
2+
3+
cudf = pytest.importorskip("cudf")
4+
5+
import crossfit as cf # noqa: E402
6+
7+
8+
def test_labeler_basic():
    """Without an axis argument, each row is labelled by its highest score."""
    rows = [
        [0.1, 0.2, 0.5],
        [0.2, 0.1, 0.3],
        [0.3, 0.2, 0.1],
        [0.2, 0.3, 0.1],
    ]
    expected = ["c", "c", "a", "b"]

    scores = cudf.Series(rows)
    predicted = cf.op.Labeler(list("abc"))(scores)

    assert predicted.to_pandas().values.tolist() == expected
21+
22+
23+
def test_labeler_first_axis():
    """With axis=0 the argmax runs down each column: one label per column."""
    rows = [
        [0.1, 0.2, 0.5],
        [0.2, 0.1, 0.3],
        [0.3, 0.2, 0.1],
        [0.2, 0.3, 0.1],
    ]
    # Four rows means four candidate positions, hence four labels.
    expected = ["c", "d", "a"]

    scores = cudf.Series(rows)
    predicted = cf.op.Labeler(list("abcd"), axis=0)(scores)

    assert predicted.to_pandas().values.tolist() == expected

0 commit comments

Comments
 (0)