Commit 143d088

angelayi authored and facebook-github-bot committed
Add dtype to kt_regroup input (#3250)
Summary:
Pull Request resolved: #3250

Currently ir_kt_regroup always returns float32, which is incorrect if the KTRegroupAsDict's emb_dtype is set.

Reviewed By: malaybag

Differential Revision: D79422720

fbshipit-source-id: 2e711a7dc186b80226b2631ba94af741e05df4b7
1 parent ac5fd1f commit 143d088
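Before this change, the traced outputs of torch.ops.torchrec.ir_kt_regroup were hard-coded to float32, so the dtype metadata in an exported graph could disagree with a KTRegroupAsDict configured for a lower-precision embedding dtype. A minimal sketch of the before/after behavior using plain torch.empty calls (shapes and the FP16 choice are illustrative, not taken from the patch):

import torch

# Before: the fake/meta implementation ignored the module's embedding dtype,
# so traced outputs were always float32.
old_out = torch.empty(4, 8, device="meta")
# After: the configured dtype is threaded through as a keyword argument.
new_out = torch.empty(4, 8, device="meta", dtype=torch.float16)

assert old_out.dtype == torch.float32
assert new_out.dtype == torch.float16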

File tree

3 files changed: 36 additions, 7 deletions

torchrec/ir/serializer.py

Lines changed: 17 additions & 2 deletions
@@ -23,7 +23,12 @@
 from torchrec.ir.types import SerializerInterface
 from torchrec.ir.utils import logging, qualname
-from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
+from torchrec.modules.embedding_configs import (
+    data_type_to_dtype,
+    DataType,
+    EmbeddingBagConfig,
+    PoolingType,
+)
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.feature_processor_ import (
     FeatureProcessor,

@@ -138,7 +143,17 @@ def kt_regroup_meta_forward(
     for i, group in enumerate(op_module._groups):
         out_lengths[i] = sum(lengths_dict[key] for key in group)
     arg_list = [kt.values() for kt in keyed_tensors]
-    outputs = torch.ops.torchrec.ir_kt_regroup(arg_list, batch_size, out_lengths)
+    dtype = (
+        data_type_to_dtype(op_module._emb_dtype)
+        if op_module._emb_dtype
+        else torch.float32
+    )
+    outputs = torch.ops.torchrec.ir_kt_regroup(
+        arg_list,
+        batch_size,
+        out_lengths,
+        dtype=dtype,
+    )
     return dict(zip(op_module._keys, outputs))
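The serializer change above resolves the output dtype from the module's _emb_dtype (falling back to float32 when it is unset) before calling the custom op. A standalone sketch of that resolution step, assuming data_type_to_dtype maps DataType.FP16 to torch.float16 (the resolve_regroup_dtype helper is illustrative, not part of the patch):

import torch
from torchrec.modules.embedding_configs import DataType, data_type_to_dtype


def resolve_regroup_dtype(emb_dtype):
    # Mirrors the fallback in kt_regroup_meta_forward: use the configured
    # embedding dtype when present, otherwise default to float32.
    return data_type_to_dtype(emb_dtype) if emb_dtype else torch.float32


print(resolve_regroup_dtype(DataType.FP16))  # expected: torch.float16
print(resolve_regroup_dtype(None))           # torch.float32 fallback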

torchrec/ir/tests/test_serializer.py

Lines changed: 9 additions & 1 deletion
@@ -23,7 +23,7 @@
     mark_dynamic_kjt,
     qualname,
 )
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import data_type_to_dtype, EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.feature_processor_ import (
     PositionWeightedModule,

@@ -822,6 +822,14 @@ def forward(
             preserve_module_call_signature=(tuple(sparse_fqns)),
         )

+        for node in ep.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.torchrec.ir_kt_regroup.default
+            ):
+                for meta in node.meta["val"]:
+                    self.assertEqual(meta.dtype, data_type_to_dtype(data_type))
+
         unflatten_ep = torch.export.unflatten(ep)
         deserialized = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
         self.assertEqual(deserialized.regroup._emb_dtype, data_type)  # pyre-ignore[16]
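The new assertion walks the exported graph and reads the fake-tensor metadata ("val") recorded on each ir_kt_regroup call node. The same kind of inspection works on any torch.export result; a toy sketch (the Scale module and its names are illustrative, not part of the test above):

import torch


class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x * 2).to(torch.float16)


ep = torch.export.export(Scale(), (torch.randn(3, 4),))
for node in ep.graph.nodes:
    if node.op == "call_function":
        val = node.meta.get("val")
        if isinstance(val, torch.Tensor):
            # FakeTensor metadata carries the traced shape and dtype.
            print(node.target, tuple(val.shape), val.dtype)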

torchrec/ir/utils.py

Lines changed: 10 additions & 4 deletions
@@ -71,20 +71,26 @@ def ir_emb_lookup_fake(

 @torch.library.custom_op("torchrec::ir_kt_regroup", mutates_args={})
 def ir_kt_regroup_impl(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+    tensors: List[Optional[torch.Tensor]],
+    batch_size: int,
+    dims: List[int],
+    dtype: torch.dtype = torch.float32,
 ) -> List[torch.Tensor]:
     device = get_device(tensors)
     logger.info(f"torch.ops.torchrec.ir_kt_regroup -> ({batch_size}, {dims}) {device}")
-    return [torch.empty(batch_size, dim, device=device) for dim in dims]
+    return [torch.empty(batch_size, dim, device=device, dtype=dtype) for dim in dims]


 @torch.library.register_fake("torchrec::ir_kt_regroup")
 def ir_kt_regroup_fake(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+    tensors: List[Optional[torch.Tensor]],
+    batch_size: int,
+    dims: List[int],
+    dtype: torch.dtype = torch.float32,
 ) -> List[torch.Tensor]:
     device = get_device(tensors)
     logger.info(f"ir_kt_regroup_fake -> ({batch_size}, {dims}) {device}")
-    return [torch.empty(batch_size, dim, device=device) for dim in dims]
+    return [torch.empty(batch_size, dim, device=device, dtype=dtype) for dim in dims]


 @torch.library.custom_op("torchrec::ir_dynamic_batch_emb_lookup", mutates_args={})
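Both the eager implementation and the fake (shape-propagation) implementation now accept the dtype keyword with a float32 default, so export-time metadata matches eager behavior. A minimal, self-contained sketch of the same pattern under a hypothetical demo::regroup_like op (names and shapes are illustrative; only the dtype threading mirrors the change above):

from typing import List

import torch


@torch.library.custom_op("demo::regroup_like", mutates_args=())
def regroup_like(
    tensors: List[torch.Tensor],
    batch_size: int,
    dims: List[int],
    dtype: torch.dtype = torch.float32,
) -> List[torch.Tensor]:
    # Eager implementation: allocate real outputs in the requested dtype.
    device = tensors[0].device if tensors else torch.device("cpu")
    return [torch.empty(batch_size, d, device=device, dtype=dtype) for d in dims]


@regroup_like.register_fake
def _(
    tensors: List[torch.Tensor],
    batch_size: int,
    dims: List[int],
    dtype: torch.dtype = torch.float32,
) -> List[torch.Tensor]:
    # Fake implementation used during tracing/export: shapes and dtype only.
    device = tensors[0].device if tensors else torch.device("cpu")
    return [torch.empty(batch_size, d, device=device, dtype=dtype) for d in dims]


outs = torch.ops.demo.regroup_like([torch.randn(2, 3)], 2, [3, 5], dtype=torch.float16)
print([t.dtype for t in outs])  # expected: [torch.float16, torch.float16]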
