Skip to content

Commit 00aa43b

Browse files
No public description
PiperOrigin-RevId: 595481514
1 parent 424ed6b commit 00aa43b

File tree

3 files changed

+34
-20
lines changed

3 files changed

+34
-20
lines changed

official/recommendation/uplift/keys.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ class StrEnum(str, enum.Enum):
2020
"""An Enum represented by a string."""
2121

2222

23-
class TwoTowerPredictionKeys(StrEnum):
24-
"""Keys for prediction tensors."""
23+
class TwoTowerOutputKeys(StrEnum):
24+
"""Keys for training and inference output tensors."""
2525

26-
UPLIFT = "uplift_predictions"
27-
CONTROL = "control_predictions"
28-
TREATMENT = "treatment_predictions"
26+
CONTROL_PREDICTIONS = "control_predictions"
27+
TREATMENT_PREDICTIONS = "treatment_predictions"
28+
UPLIFT_PREDICTIONS = "uplift_predictions"
29+
IS_TREATMENT = "is_treatment"
30+
TRUE_LOGITS = "true_logits"

official/recommendation/uplift/models/two_tower_uplift_model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,13 @@ def predict_step(self, data) -> dict[str, tf.Tensor]:
9696
outputs = super().predict_step(data)
9797

9898
return {
99-
keys.TwoTowerPredictionKeys.CONTROL: outputs.control_predictions,
100-
keys.TwoTowerPredictionKeys.TREATMENT: outputs.treatment_predictions,
101-
keys.TwoTowerPredictionKeys.UPLIFT: outputs.uplift,
99+
keys.TwoTowerOutputKeys.CONTROL_PREDICTIONS: (
100+
outputs.control_predictions
101+
),
102+
keys.TwoTowerOutputKeys.TREATMENT_PREDICTIONS: (
103+
outputs.treatment_predictions
104+
),
105+
keys.TwoTowerOutputKeys.UPLIFT_PREDICTIONS: outputs.uplift,
102106
}
103107

104108
def get_config(self) -> Mapping[str, Any]:

official/recommendation/uplift/models/two_tower_uplift_model_test.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ def test_model_training_and_inference(self):
121121

122122
# Test model inference predictions.
123123
expected_predictions = {
124-
keys.TwoTowerPredictionKeys.CONTROL: tf.zeros((10, 1)),
125-
keys.TwoTowerPredictionKeys.TREATMENT: 3 * tf.ones((10, 1)),
126-
keys.TwoTowerPredictionKeys.UPLIFT: 3 * tf.ones((10, 1)),
124+
keys.TwoTowerOutputKeys.CONTROL_PREDICTIONS: tf.zeros((10, 1)),
125+
keys.TwoTowerOutputKeys.TREATMENT_PREDICTIONS: 3 * tf.ones((10, 1)),
126+
keys.TwoTowerOutputKeys.UPLIFT_PREDICTIONS: 3 * tf.ones((10, 1)),
127127
}
128128
self.assertAllClose(expected_predictions, model.predict(dataset))
129129

@@ -132,31 +132,39 @@ def test_model_training_and_inference(self):
132132
"testcase_name": "identity",
133133
"inverse_link_fn": tf.identity,
134134
"expected_predictions": {
135-
keys.TwoTowerPredictionKeys.CONTROL: (
135+
keys.TwoTowerOutputKeys.CONTROL_PREDICTIONS: (
136136
tf.ones((3, 1)) * -1.0
137137
), # 1 - 2 = -1
138-
keys.TwoTowerPredictionKeys.TREATMENT: (
138+
keys.TwoTowerOutputKeys.TREATMENT_PREDICTIONS: (
139139
tf.ones((3, 1)) * 4.0
140140
), # 1 + 3 = 4
141-
keys.TwoTowerPredictionKeys.UPLIFT: tf.ones((3, 1)) * 5.0,
141+
keys.TwoTowerOutputKeys.UPLIFT_PREDICTIONS: tf.ones((3, 1)) * 5.0,
142142
},
143143
},
144144
{
145145
"testcase_name": "abs",
146146
"inverse_link_fn": tf.math.abs,
147147
"expected_predictions": {
148-
keys.TwoTowerPredictionKeys.CONTROL: tf.ones((3, 1)) * 1.0,
149-
keys.TwoTowerPredictionKeys.TREATMENT: tf.ones((3, 1)) * 4.0,
150-
keys.TwoTowerPredictionKeys.UPLIFT: tf.ones((3, 1)) * 3.0,
148+
keys.TwoTowerOutputKeys.CONTROL_PREDICTIONS: (
149+
tf.ones((3, 1)) * 1.0
150+
),
151+
keys.TwoTowerOutputKeys.TREATMENT_PREDICTIONS: (
152+
tf.ones((3, 1)) * 4.0
153+
),
154+
keys.TwoTowerOutputKeys.UPLIFT_PREDICTIONS: tf.ones((3, 1)) * 3.0,
151155
},
152156
},
153157
{
154158
"testcase_name": "relu",
155159
"inverse_link_fn": tf_keras.activations.relu,
156160
"expected_predictions": {
157-
keys.TwoTowerPredictionKeys.CONTROL: tf.ones((3, 1)) * 0.0,
158-
keys.TwoTowerPredictionKeys.TREATMENT: tf.ones((3, 1)) * 4.0,
159-
keys.TwoTowerPredictionKeys.UPLIFT: tf.ones((3, 1)) * 4.0,
161+
keys.TwoTowerOutputKeys.CONTROL_PREDICTIONS: (
162+
tf.ones((3, 1)) * 0.0
163+
),
164+
keys.TwoTowerOutputKeys.TREATMENT_PREDICTIONS: (
165+
tf.ones((3, 1)) * 4.0
166+
),
167+
keys.TwoTowerOutputKeys.UPLIFT_PREDICTIONS: tf.ones((3, 1)) * 4.0,
160168
},
161169
},
162170
)

0 commit comments

Comments
 (0)