Skip to content

Commit 0441e55

Browse files
Lifannrhdong
authored andcommitted
fix(comment): Fix error in comment.
1 parent 6e98ba7 commit 0441e55

File tree

1 file changed

+9
-3
lines changed
  • tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers

1 file changed

+9
-3
lines changed

tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,17 +311,23 @@ class FieldWiseEmbedding(BasicEmbedding):
311311
```python
312312
nslots = 3
313313
@tf.function
314-
def map_slot_fn(feature_id):
314+
def feature_to_slot(feature_id):
315315
field_id = tf.math.mod(feature_id, nslots)
316316
return field_id
317317
318318
ids = tf.constant([[23, 12, 0], [9, 13, 10]], dtype=tf.int64)
319-
embedding = de.layers.FieldWiseEmbedding(1, nslots, map_slot_fn)
319+
embedding = de.layers.FieldWiseEmbedding(2,
320+
nslots,
321+
slot_map_fn=feature_to_slot,
322+
initializer=tf.keras.initializer.Zeros())
323+
324+
out = embedding(ids)
325+
# [[[0., 0.], [0., 0.], [0., 1.]]
326+
# [[0., 0.], [0., 0.], [0., 1.]]]
320327
321328
prepared_keys = tf.range(0, 100, dtype=tf.int64)
322329
prepared_values = tf.ones((100, 2), dtype=tf.float32)
323330
embedding.params.upsert(prepared_keys, prepared_values)
324-
325331
out = embedding(ids)
326332
# [[2., 2.], [0., 0.], [1., 1.]]
327333
# [[1., 1.], [2., 2.], [0., 0.]]

0 commit comments

Comments
 (0)