@@ -828,6 +828,15 @@ def _ensure_bytes(l):
828
828
raise ValueError ('expected bytes, found %r' % l )
829
829
830
830
831
+ def _ensure_text (l ):
832
+ if isinstance (l , text_type ):
833
+ return l
834
+ elif isinstance (l , binary_type ):
835
+ return text_type (l , 'ascii' )
836
+ else :
837
+ raise ValueError ('expected text, found %r' % l )
838
+
839
+
831
840
class Categorize (Codec ):
832
841
"""Filter encoding categorical string data as integers.
833
842
@@ -862,10 +871,13 @@ class Categorize(Codec):
862
871
codec_id = 'categorize'
863
872
864
873
def __init__ (self , labels , dtype , astype = 'u1' ):
865
- self .labels = [_ensure_bytes (l ) for l in labels ]
866
874
self .dtype = np .dtype (dtype )
867
- if self .dtype .kind != 'S' :
868
- raise ValueError ('only string data types are supported' )
875
+ if self .dtype .kind == 'S' :
876
+ self .labels = [_ensure_bytes (l ) for l in labels ]
877
+ elif self .dtype .kind == 'U' :
878
+ self .labels = [_ensure_text (l ) for l in labels ]
879
+ else :
880
+ raise ValueError ('data type not supported' )
869
881
self .astype = np .dtype (astype )
870
882
871
883
def encode (self , buf ):
@@ -909,7 +921,7 @@ def decode(self, buf, out=None):
909
921
def get_config (self ):
910
922
config = dict ()
911
923
config ['id' ] = self .codec_id
912
- config ['labels' ] = [text_type ( l , 'ascii' ) for l in self .labels ]
924
+ config ['labels' ] = [_ensure_text ( l ) for l in self .labels ]
913
925
config ['dtype' ] = encode_dtype (self .dtype )
914
926
config ['astype' ] = encode_dtype (self .astype )
915
927
return config
@@ -922,8 +934,12 @@ def from_config(cls, config):
922
934
return cls (labels = labels , dtype = dtype , astype = astype )
923
935
924
936
def __repr__ (self ):
925
- r = '%s(dtype=%s, astype=%s, labels=%r)' % \
926
- (type (self ).__name__ , self .dtype , self .astype , self .labels )
937
+ # make sure labels part is not too long
938
+ labels = repr (self .labels [:3 ])
939
+ if len (self .labels ) > 3 :
940
+ labels = labels [:- 1 ] + ', ...]'
941
+ r = '%s(dtype=%s, astype=%s, labels=%s)' % \
942
+ (type (self ).__name__ , self .dtype , self .astype , labels )
927
943
return r
928
944
929
945
0 commit comments