We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6e98ba7 commit 0441e55Copy full SHA for 0441e55
tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py
@@ -311,17 +311,23 @@ class FieldWiseEmbedding(BasicEmbedding):
311
```python
312
nslots = 3
313
@tf.function
314
- def map_slot_fn(feature_id):
+ def feature_to_slot(feature_id):
315
field_id = tf.math.mod(feature_id, nslots)
316
return field_id
317
318
ids = tf.constant([[23, 12, 0], [9, 13, 10]], dtype=tf.int64)
319
- embedding = de.layers.FieldWiseEmbedding(1, nslots, map_slot_fn)
+ embedding = de.layers.FieldWiseEmbedding(2,
320
+ nslots,
321
+ slot_map_fn=feature_to_slot,
322
+ initializer=tf.keras.initializer.Zeros())
323
+
324
+ out = embedding(ids)
325
+ # [[[0., 0.], [0., 0.], [0., 1.]]
326
+ # [[0., 0.], [0., 0.], [0., 1.]]]
327
328
prepared_keys = tf.range(0, 100, dtype=tf.int64)
329
prepared_values = tf.ones((100, 2), dtype=tf.float32)
330
embedding.params.upsert(prepared_keys, prepared_values)
-
331
out = embedding(ids)
332
# [[2., 2.], [0., 0.], [1., 1.]]
333
# [[1., 1.], [2., 2.], [0., 0.]]
0 commit comments