@@ -121,9 +121,9 @@ def test_model_training_and_inference(self):
121
121
122
122
# Test model inference predictions.
123
123
expected_predictions = {
124
- keys .TwoTowerPredictionKeys . CONTROL : tf .zeros ((10 , 1 )),
125
- keys .TwoTowerPredictionKeys . TREATMENT : 3 * tf .ones ((10 , 1 )),
126
- keys .TwoTowerPredictionKeys . UPLIFT : 3 * tf .ones ((10 , 1 )),
124
+ keys .TwoTowerOutputKeys . CONTROL_PREDICTIONS : tf .zeros ((10 , 1 )),
125
+ keys .TwoTowerOutputKeys . TREATMENT_PREDICTIONS : 3 * tf .ones ((10 , 1 )),
126
+ keys .TwoTowerOutputKeys . UPLIFT_PREDICTIONS : 3 * tf .ones ((10 , 1 )),
127
127
}
128
128
self .assertAllClose (expected_predictions , model .predict (dataset ))
129
129
@@ -132,31 +132,39 @@ def test_model_training_and_inference(self):
132
132
"testcase_name" : "identity" ,
133
133
"inverse_link_fn" : tf .identity ,
134
134
"expected_predictions" : {
135
- keys .TwoTowerPredictionKeys . CONTROL : (
135
+ keys .TwoTowerOutputKeys . CONTROL_PREDICTIONS : (
136
136
tf .ones ((3 , 1 )) * - 1.0
137
137
), # 1 - 2 = -1
138
- keys .TwoTowerPredictionKeys . TREATMENT : (
138
+ keys .TwoTowerOutputKeys . TREATMENT_PREDICTIONS : (
139
139
tf .ones ((3 , 1 )) * 4.0
140
140
), # 1 + 3 = 4
141
- keys .TwoTowerPredictionKeys . UPLIFT : tf .ones ((3 , 1 )) * 5.0 ,
141
+ keys .TwoTowerOutputKeys . UPLIFT_PREDICTIONS : tf .ones ((3 , 1 )) * 5.0 ,
142
142
},
143
143
},
144
144
{
145
145
"testcase_name" : "abs" ,
146
146
"inverse_link_fn" : tf .math .abs ,
147
147
"expected_predictions" : {
148
- keys .TwoTowerPredictionKeys .CONTROL : tf .ones ((3 , 1 )) * 1.0 ,
149
- keys .TwoTowerPredictionKeys .TREATMENT : tf .ones ((3 , 1 )) * 4.0 ,
150
- keys .TwoTowerPredictionKeys .UPLIFT : tf .ones ((3 , 1 )) * 3.0 ,
148
+ keys .TwoTowerOutputKeys .CONTROL_PREDICTIONS : (
149
+ tf .ones ((3 , 1 )) * 1.0
150
+ ),
151
+ keys .TwoTowerOutputKeys .TREATMENT_PREDICTIONS : (
152
+ tf .ones ((3 , 1 )) * 4.0
153
+ ),
154
+ keys .TwoTowerOutputKeys .UPLIFT_PREDICTIONS : tf .ones ((3 , 1 )) * 3.0 ,
151
155
},
152
156
},
153
157
{
154
158
"testcase_name" : "relu" ,
155
159
"inverse_link_fn" : tf_keras .activations .relu ,
156
160
"expected_predictions" : {
157
- keys .TwoTowerPredictionKeys .CONTROL : tf .ones ((3 , 1 )) * 0.0 ,
158
- keys .TwoTowerPredictionKeys .TREATMENT : tf .ones ((3 , 1 )) * 4.0 ,
159
- keys .TwoTowerPredictionKeys .UPLIFT : tf .ones ((3 , 1 )) * 4.0 ,
161
+ keys .TwoTowerOutputKeys .CONTROL_PREDICTIONS : (
162
+ tf .ones ((3 , 1 )) * 0.0
163
+ ),
164
+ keys .TwoTowerOutputKeys .TREATMENT_PREDICTIONS : (
165
+ tf .ones ((3 , 1 )) * 4.0
166
+ ),
167
+ keys .TwoTowerOutputKeys .UPLIFT_PREDICTIONS : tf .ones ((3 , 1 )) * 4.0 ,
160
168
},
161
169
},
162
170
)
0 commit comments