@@ -87,11 +87,21 @@ def __hash__(self):
87
87
return self .name .__hash__ ()
88
88
89
89
90
- class DenseFeat (namedtuple ('DenseFeat' , ['name' , 'dimension' , 'dtype' ])):
90
+ class DenseFeat (namedtuple ('DenseFeat' , ['name' , 'dimension' , 'dtype' , 'transform_fn' ])):
91
+ """ Dense feature
92
+ Args:
93
+ name: feature name,
94
+ dimension: dimension of the feature, default = 1.
95
+ dtype: dtype of the feature, default="float32".
96
+ transform_fn: If not None, a function that can be used to transfrom
97
+ values of the feature. the function takes the input Tensor as its
98
+ argument, and returns the output Tensor.
99
+ (e.g. lambda x: (x - 3.0) / 4.2).
100
+ """
91
101
__slots__ = ()
92
102
93
- def __new__ (cls , name , dimension = 1 , dtype = "float32" ):
94
- return super (DenseFeat , cls ).__new__ (cls , name , dimension , dtype )
103
+ def __new__ (cls , name , dimension = 1 , dtype = "float32" , transform_fn = None ):
104
+ return super (DenseFeat , cls ).__new__ (cls , name , dimension , dtype , transform_fn )
95
105
96
106
def __hash__ (self ):
97
107
return self .name .__hash__ ()
0 commit comments