Skip to content

Commit 0167503

Browse files
SSYernarfacebook-github-bot
authored andcommitted
Refactoring and moving the EBC-only logic into ‎embedding_collection_wrappers.py (#3251)
Summary: Pull Request resolved: #3251 - Moved the EBC-only logic into a separate file `embedding_collection_wrappers.py` - Removed linter issues Reviewed By: spmex Differential Revision: D79512602 fbshipit-source-id: 89357ef2383afabcbfacfbb4ff30f49154e4bea1
1 parent ee1da3a commit 0167503

File tree

5 files changed

+725
-672
lines changed

5 files changed

+725
-672
lines changed

examples/zch/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torch
1414

1515
from torchrec import EmbeddingConfig, KeyedJaggedTensor
16-
from torchrec.distributed.benchmark.benchmark_utils import get_inputs
1716
from tqdm import tqdm
1817

1918
from .sparse_arch import SparseArch

torchrec/distributed/benchmark/benchmark_inference.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@
1919
import torch
2020

2121
from torchrec.distributed.benchmark.benchmark_utils import (
22-
benchmark_module,
2322
BenchmarkResult,
2423
CompileMode,
2524
DLRM_NUM_EMBEDDINGS_PER_FEATURE,
2625
EMBEDDING_DIM,
27-
get_tables,
2826
init_argparse_and_args,
2927
write_report,
3028
)
29+
from torchrec.distributed.benchmark.embedding_collection_wrappers import (
30+
benchmark_ebc_module,
31+
get_tables,
32+
)
3133
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
3234
from torchrec.distributed.test_utils.infer_utils import (
3335
TestQuantEBCSharder,
@@ -84,7 +86,7 @@ def benchmark_qec(args: argparse.Namespace, output_dir: str) -> List[BenchmarkRe
8486
if not argname.startswith("_") and argname not in IGNORE_ARGNAME
8587
}
8688

87-
return benchmark_module(
89+
return benchmark_ebc_module(
8890
module=module,
8991
sharder=sharder,
9092
sharding_types=BENCH_SHARDING_TYPES,
@@ -118,7 +120,7 @@ def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkR
118120
if not argname.startswith("_") and argname not in IGNORE_ARGNAME
119121
}
120122

121-
return benchmark_module(
123+
return benchmark_ebc_module(
122124
module=module,
123125
sharder=sharder,
124126
sharding_types=BENCH_SHARDING_TYPES,
@@ -153,7 +155,7 @@ def benchmark_qec_unsharded(
153155
if not argname.startswith("_") and argname not in IGNORE_ARGNAME
154156
}
155157

156-
return benchmark_module(
158+
return benchmark_ebc_module(
157159
module=module,
158160
sharder=sharder,
159161
sharding_types=[],
@@ -190,7 +192,7 @@ def benchmark_qebc_unsharded(
190192
if not argname.startswith("_") and argname not in IGNORE_ARGNAME
191193
}
192194

193-
return benchmark_module(
195+
return benchmark_ebc_module(
194196
module=module,
195197
sharder=sharder,
196198
sharding_types=[],

torchrec/distributed/benchmark/benchmark_train.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,17 @@
1717
from typing import List, Optional, Tuple
1818

1919
import torch
20-
2120
from torchrec.distributed.benchmark.benchmark_utils import (
22-
benchmark_module,
2321
BenchmarkResult,
2422
CompileMode,
25-
get_tables,
2623
init_argparse_and_args,
2724
set_embedding_config,
2825
write_report,
2926
)
27+
from torchrec.distributed.benchmark.embedding_collection_wrappers import (
28+
benchmark_ebc_module,
29+
get_tables,
30+
)
3031
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
3132
from torchrec.distributed.test_utils.test_model import TestEBCSharder
3233
from torchrec.distributed.types import DataType
@@ -106,7 +107,7 @@ def benchmark_ebc(
106107

107108
args_kwargs["variable_batch_embeddings"] = variable_batch_embeddings
108109

109-
return benchmark_module(
110+
return benchmark_ebc_module(
110111
module=module,
111112
sharder=sharder,
112113
sharding_types=BENCH_SHARDING_TYPES,

0 commit comments

Comments
 (0)