1
1
import tensorflow as tf
2
2
import tensorflow_datasets as tfds
3
- import sys
4
3
5
- ENCODDING_SEGMENT_LENGTH = 1000000
4
+ ENCODING_SEGMENT_LENGTH = 1000000
6
5
NON_LETTER_OR_NUMBER_PATTERN = r'[^a-zA-Z0-9]'
7
6
8
7
FAETURES = [
12
11
]
13
12
LABEL = 'verified_purchase'
14
13
14
+ NUM_FEATURE_SLOTS = 0
15
+
15
16
16
17
class _RawFeature (object ):
17
18
"""
@@ -22,13 +23,15 @@ def __init__(self, dtype, category):
22
23
if not isinstance (category , int ):
23
24
raise TypeError ('category must be an integer.' )
24
25
self .category = category
26
+ global NUM_FEATURE_SLOTS
27
+ NUM_FEATURE_SLOTS = max (NUM_FEATURE_SLOTS , self .category )
25
28
26
29
def encode(self, tensor):
    """Encode a raw feature tensor into integer codes.

    This base implementation is a placeholder: it always raises, so each
    concrete feature type has to supply its own `encode`.

    Args:
        tensor: the raw feature values to encode.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
28
31
29
32
def match_category(self, tensor):
    """Return a boolean mask of the codes that belong to this feature.

    A feature with category `c` owns the half-open code range
    [c * ENCODING_SEGMENT_LENGTH, (c + 1) * ENCODING_SEGMENT_LENGTH).

    Args:
        tensor: tensor of encoded feature codes.

    Returns:
        A boolean tensor, True wherever the code falls inside this
        feature's range.
    """
    min_code = self.category * ENCODING_SEGMENT_LENGTH
    max_code = (self.category + 1) * ENCODING_SEGMENT_LENGTH
    # Lower bound is inclusive, upper bound exclusive, so adjacent
    # categories never overlap.
    mask = tf.math.logical_and(tf.greater_equal(tensor, min_code),
                               tf.less(tensor, max_code))
    return mask
@@ -40,8 +43,8 @@ def __init__(self, dtype, category):
40
43
super (_StringFeature , self ).__init__ (dtype , category )
41
44
42
45
def encode(self, tensor):
    """Hash string values into this feature's code segment.

    Args:
        tensor: string tensor of raw feature values.

    Returns:
        Integer tensor of codes: the hash bucket of each string (one of
        ENCODING_SEGMENT_LENGTH buckets) shifted by
        ENCODING_SEGMENT_LENGTH * self.category.
    """
    tensor = tf.strings.to_hash_bucket_fast(tensor, ENCODING_SEGMENT_LENGTH)
    # Shift into the segment reserved for this category so codes from
    # different features do not collide.
    tensor += ENCODING_SEGMENT_LENGTH * self.category
    return tensor
46
49
47
50
@@ -53,8 +56,8 @@ def __init__(self, dtype, category):
53
56
def encode(self, tensor):
    """Tokenize free text and hash every token into this feature's segment.

    Args:
        tensor: string tensor of raw text values.

    Returns:
        Dense integer tensor of per-token codes: each token's hash bucket
        (one of ENCODING_SEGMENT_LENGTH buckets) shifted by
        ENCODING_SEGMENT_LENGTH * self.category.
    """
    # Replace every character that is not a letter or digit with a space
    # so that splitting on ' ' yields alphanumeric tokens.
    tensor = tf.strings.regex_replace(tensor, NON_LETTER_OR_NUMBER_PATTERN, ' ')
    # The split result is ragged; densify it, padding short rows with ''.
    # Note that the '' padding is hashed like any other token below.
    tensor = tf.strings.split(tensor, sep=' ').to_tensor('')
    tensor = tf.strings.to_hash_bucket_fast(tensor, ENCODING_SEGMENT_LENGTH)
    # Shift into the segment reserved for this category.
    tensor += ENCODING_SEGMENT_LENGTH * self.category
    return tensor
59
62
60
63
@@ -65,23 +68,23 @@ def __init__(self, dtype, category):
65
68
66
69
def encode(self, tensor):
    """Hash numeric values into this feature's code segment.

    The values are converted to their string form first, then hashed the
    same way string features are.

    Args:
        tensor: numeric tensor of raw feature values.

    Returns:
        Integer tensor of codes: the hash bucket of each stringified value
        (one of ENCODING_SEGMENT_LENGTH buckets) shifted by
        ENCODING_SEGMENT_LENGTH * self.category.
    """
    tensor = tf.as_string(tensor)
    tensor = tf.strings.to_hash_bucket_fast(tensor, ENCODING_SEGMENT_LENGTH)
    # Shift into the segment reserved for this category.
    tensor += ENCODING_SEGMENT_LENGTH * self.category
    return tensor
71
74
72
75
73
76
# Maps each input feature name to the encoder that handles it. The integer
# passed to each encoder is its category; categories are zero-based and each
# encoder offsets its codes by ENCODING_SEGMENT_LENGTH * category.
# Keys carry no surrounding whitespace so they can match the feature names
# exactly.
FEATURE_AND_ENCODER = {
    'customer_id': _StringFeature(tf.string, 0),
    'helpful_votes': _IntegerFeature(tf.int32, 1),
    'product_category': _StringFeature(tf.string, 2),
    'product_id': _StringFeature(tf.string, 3),
    'product_parent': _StringFeature(tf.string, 4),
    'product_title': _TextFeature(tf.string, 5),
    'review_headline': _TextFeature(tf.string, 6),
    'review_id': _StringFeature(tf.string, 7),
    'star_rating': _IntegerFeature(tf.int32, 8),
    'total_votes': _IntegerFeature(tf.int32, 9),
    # Disabled, marked as a bad feature:
    # 'review_body': _TextFeature(tf.string, 10),
}
86
89
87
90
@@ -99,6 +102,12 @@ def encode_feature(data):
99
102
return collected_features
100
103
101
104
105
@tf.function
def get_category(tensor):
    """Recover the feature category from encoded feature codes.

    Encoders place a feature's codes in the segment that starts at
    ENCODING_SEGMENT_LENGTH * category, so floor-dividing a code by
    ENCODING_SEGMENT_LENGTH gives back the category index.

    Args:
        tensor: integer tensor of encoded feature codes.

    Returns:
        Tensor of category indices, computed element-wise.
    """
    return tf.math.floordiv(tensor, ENCODING_SEGMENT_LENGTH)
109
+
110
+
102
111
def get_labels(data):
    """Return the label column of a batch of examples.

    Args:
        data: mapping of feature name to values; must contain the
            'verified_purchase' entry.

    Returns:
        The value stored under 'verified_purchase'.
    """
    # Same key as the module-level LABEL constant.
    return data['verified_purchase']
104
113
0 commit comments