11from typing import List , Union
22
33import cudf
4+ import cupy as cp
45
56from crossfit .op .base import Op
67
@@ -12,29 +13,27 @@ def __init__(
1213 cols = None ,
1314 keep_cols = None ,
1415 pre = None ,
15- keep_prob : bool = False ,
1616 suffix : str = "labels" ,
17+ axis = - 1 ,
1718 ):
1819 super ().__init__ (pre = pre , cols = cols , keep_cols = keep_cols )
1920 self .labels = labels
20- self .keep_prob = keep_prob
2121 self .suffix = suffix
22+ self .axis = axis
2223
def call_column(self, data: cudf.Series) -> cudf.Series:
    """Map each row of per-class scores to the label with the highest score.

    The flattened leaf values of ``data`` are reshaped to
    ``(number of rows,) + shape of the first row`` and reduced with
    ``argmax`` along ``self.axis``. The resulting integer positions are
    translated to the entries of ``self.labels``.

    Args:
        data: Series accessed through the ``.list`` accessor, so it is
            expected to hold list values. Every row is assumed to have
            the same shape as the first row, because that shape is used
            for the reshape.

    Returns:
        A new Series (with a default index) holding the label selected
        for each row. An empty input yields an empty Series.

    Raises:
        ValueError: If ``data`` is a DataFrame rather than a Series.
        RuntimeError: If the argmax along ``self.axis`` does not produce
            exactly one class index per row (a 1-d result).
    """
    if isinstance(data, cudf.DataFrame):
        raise ValueError(
            "data must be a Series, got DataFrame. Add a pre step to convert to Series"
        )

    # Nothing to classify; also avoids indexing the first row of an
    # empty Series below, which would raise.
    if data.size == 0:
        return cudf.Series([], dtype="object")

    # The first row defines the per-row shape used for every row.
    row_shape = cp.asarray(data.iloc[0]).shape
    scores = data.list.leaves.values.reshape((data.size,) + row_shape)
    classes = scores.argmax(self.axis)

    # Exactly one class index per row is required; this rejects both
    # higher-rank results and a scalar result (e.g. a flattened argmax).
    if classes.ndim != 1:
        raise RuntimeError(f"Max category of the axis {self.axis} of data is not a 1-d array.")

    labels_map = dict(enumerate(self.labels))

    return cudf.Series(classes).map(labels_map)
@@ -60,7 +59,7 @@ def call(self, data: Union[cudf.Series, cudf.DataFrame]) -> Union[cudf.Series, c
6059 def meta (self ):
6160 labeled = {"labels" : "string" }
6261
63- if len (self .cols ) > 1 :
62+ if self . cols and len (self .cols ) > 1 :
6463 labeled = {
6564 self ._construct_name (col , suffix ): dtype
6665 for col in self .cols
0 commit comments