Skip to content

Commit a2fdb42

Browse files
Kathy Xu authored and facebook-github-bot committed
add unit test for offloading (metaheader included) (#3202)
Summary: Pull Request resolved: #3202 X-link: facebookresearch/FBGEMM#1560 Added unit test for the following conditions, and fixed related bugs: 1. ZCH fused optimizer with offloading 2. Added DRAM kernel for state dict loading (with metaheader) and numerical accuracy 3. Applied return whole row for DRAM kernel. Reviewed By: emlin, bobbyliujb Differential Revision: D77474510 fbshipit-source-id: e307077acce9f40cf733d8382bc45c6a931ae1e4
1 parent 94fc482 commit a2fdb42

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def __init__( # noqa C901
430430
)
431431

432432
all_optimizer_states = emb_module.get_optimizer_state(
433-
sorted_id_tensor=sorted_id_tensors
433+
sorted_id_tensor=sorted_id_tensors,
434434
)
435435
opt_param_list = [param["momentum1"] for param in all_optimizer_states]
436436
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)

torchrec/distributed/test_utils/test_model_parallel_base.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -421,20 +421,21 @@ def _compare_models(
421421
src_wid = sd1[wid_key].local_shards()[local_shard_id].tensor
422422
dst_wid = sd2[wid_key].local_shards()[local_shard_id].tensor
423423

424-
sorted_src_wid, _ = torch.sort(src_wid.view(-1))
425-
sorted_dst_wid, _ = torch.sort(dst_wid.view(-1))
424+
sorted_src_wid = torch.sort(src_wid.view(-1))[0]
425+
sorted_dst_wid = torch.sort(dst_wid.view(-1))[0]
426426
assert torch.equal(sorted_src_wid, sorted_dst_wid)
427-
src_tensor = src.tensor.get_weights_by_ids(src_wid)
428-
dst_tensor = dst.tensor.get_weights_by_ids(dst_wid)
427+
# kvz zch emb table comparison, id is non-continuous
428+
src_tensor = src.tensor.get_weights_by_ids(sorted_src_wid)
429+
dst_tensor = dst.tensor.get_weights_by_ids(sorted_dst_wid)
429430
else:
430431
# normal ssd offloading emb table comparison
431432
src_tensor = src.tensor.full_tensor()
432433
dst_tensor = dst.tensor.full_tensor()
433434
else:
434-
src_tensor = src.tensor
435-
dst_tensor = dst.tensor
435+
src_tensor = torch.sort(src.tensor.flatten()).values
436+
dst_tensor = torch.sort(dst.tensor.flatten()).values
436437
if is_deterministic:
437-
self.assertTrue(torch.equal(src_tensor, dst_tensor))
438+
self.assertTrue(torch.allclose(src_tensor, dst_tensor))
438439
else:
439440
rtol, atol = _get_default_rtol_and_atol(src_tensor, dst_tensor)
440441
torch.testing.assert_close(

torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,7 @@ def _copy_fused_modules_into_ssd_emb_modules(
985985
kernel_type=st.sampled_from(
986986
[
987987
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
988+
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
988989
]
989990
),
990991
sharding_type=st.sampled_from(
@@ -1387,6 +1388,10 @@ def _copy_ssd_emb_modules(
13871388
pmt2 = sharded_t2.local_shards()[0].tensor
13881389
pmt2.wrapped.set_weights_and_ids(w1, w1_id.view(-1))
13891390

1391+
# Remove the cache to force state dict read from backend again
1392+
emb_module1._split_weights_res = None
1393+
emb_module2._split_weights_res = None
1394+
13901395
# purge after loading. This is needed, since we pass a batch
13911396
# through dmp when instantiating them.
13921397
emb_module1.purge()

0 commit comments

Comments (0)