
Commit 091ec6b

SSYernar authored and facebook-github-bot committed

Generalize module benchmarking (#3252)

Summary:
Pull Request resolved: #3252

Added a `benchmark_module` function that benchmarks the given module, sharding type, and planner across multiple GPUs.

Reviewed By: aliafzal

Differential Revision: D79515598

fbshipit-source-id: 5409096c9a0d6c7c7dfb8a19d640761128362604

1 parent 50f9129 · commit 091ec6b
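
The summary above describes the new `benchmark_module` entry point, but the file that adds it is not part of the hunks shown below. As a rough, hypothetical sketch only (the real signature of `benchmark_module` is not visible in this diff, so the call is left commented out and its argument names are assumptions), usage might look like:

```python
# Hypothetical sketch only: the actual signature of benchmark_module added by
# this commit is not shown in the hunks below, so the call is commented out.
import torch
from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

tables = [
    EmbeddingBagConfig(
        num_embeddings=100_000,
        embedding_dim=128,
        name="table_0",
        feature_names=["feature_0"],
    )
]
# Build the module to be benchmarked on the meta device (no real allocation).
module = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))

# Per the summary, benchmark_module takes the module, a sharding type, and a
# planner and runs the benchmark across multiple GPUs (argument names assumed):
# result = benchmark_module(
#     module=module,
#     sharding_type=ShardingType.TABLE_WISE,
#     planner=None,
# )
```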

File tree

3 files changed: +490 / -216 lines changed


torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 2 additions & 212 deletions
```diff
@@ -16,24 +16,14 @@
 3. Add model-specific params to ModelSelectionConfig and create_model_config's arguments in benchmark_train_pipeline.py
 """
 
-import copy
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Type, Union
 
 import torch
-import torch.distributed as dist
-from torch import nn, optim
-from torch.optim import Optimizer
-from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
-from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
-from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
-from torchrec.distributed.planner.types import ParameterConstraints
+from torch import nn
 from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.distributed.test_utils.test_model import (
-    TestEBCSharder,
     TestSparseNN,
     TestTowerCollectionSparseNN,
     TestTowerSparseNN,
@@ -47,7 +37,6 @@
     PrefetchTrainPipelineSparseDist,
     TrainPipelineSemiSync,
 )
-from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
 from torchrec.models.deepfm import SimpleDeepFMNNWrapper
 from torchrec.models.dlrm import DLRMWrapper
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -248,55 +237,6 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
     return model_class(**filtered_kwargs)
 
 
-def generate_tables(
-    num_unweighted_features: int,
-    num_weighted_features: int,
-    embedding_feature_dim: int,
-) -> Tuple[
-    List[EmbeddingBagConfig],
-    List[EmbeddingBagConfig],
-]:
-    """
-    Generate embedding bag configurations for both unweighted and weighted features.
-
-    This function creates two lists of EmbeddingBagConfig objects:
-    1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}"
-    2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}"
-
-    For both types, the number of embeddings scales with the feature index,
-    calculated as max(i + 1, 100) * 1000.
-
-    Args:
-        num_unweighted_features (int): Number of unweighted features to generate.
-        num_weighted_features (int): Number of weighted features to generate.
-        embedding_feature_dim (int): Dimension of the embedding vectors.
-
-    Returns:
-        Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing
-        two lists - the first for unweighted embedding tables and the second for
-        weighted embedding tables.
-    """
-    tables = [
-        EmbeddingBagConfig(
-            num_embeddings=max(i + 1, 100) * 1000,
-            embedding_dim=embedding_feature_dim,
-            name="table_" + str(i),
-            feature_names=["feature_" + str(i)],
-        )
-        for i in range(num_unweighted_features)
-    ]
-    weighted_tables = [
-        EmbeddingBagConfig(
-            num_embeddings=max(i + 1, 100) * 1000,
-            embedding_dim=embedding_feature_dim,
-            name="weighted_table_" + str(i),
-            feature_names=["weighted_feature_" + str(i)],
-        )
-        for i in range(num_weighted_features)
-    ]
-    return tables, weighted_tables
-
-
 def generate_pipeline(
     pipeline_type: str,
     emb_lookup_stream: str,
```
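
For context, a minimal usage sketch of `generate_tables` as defined in the block removed above. The helper's new import location is not visible in this file's diff (the import hunks in benchmark_train_pipeline.py below only show that it now sits next to `benchmark_func` and `cmd_conf`), so no import statement is shown:

```python
# Sketch: calling generate_tables exactly as defined in the removed block above.
# Its new import path is not visible in this diff, so no import is shown.
tables, weighted_tables = generate_tables(
    num_unweighted_features=26,
    num_weighted_features=8,
    embedding_feature_dim=128,
)

# Each entry is an EmbeddingBagConfig, e.g.:
#   tables[0].name            == "table_0"
#   tables[0].num_embeddings  == max(0 + 1, 100) * 1000 == 100_000
#   weighted_tables[0].name   == "weighted_table_0"
```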
```diff
@@ -371,156 +311,6 @@ def generate_pipeline(
     return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit)
 
 
-def generate_planner(
-    planner_type: str,
-    topology: Topology,
-    tables: Optional[List[EmbeddingBagConfig]],
-    weighted_tables: Optional[List[EmbeddingBagConfig]],
-    sharding_type: ShardingType,
-    compute_kernel: EmbeddingComputeKernel,
-    batch_sizes: List[int],
-    pooling_factors: Optional[List[float]],
-    num_poolings: Optional[List[float]],
-) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
-    """
-    Generate an embedding sharding planner based on the specified configuration.
-
-    Args:
-        planner_type: Type of planner to use ("embedding" or "hetero")
-        topology: Network topology for distributed training
-        tables: List of unweighted embedding tables
-        weighted_tables: List of weighted embedding tables
-        sharding_type: Strategy for sharding embedding tables
-        compute_kernel: Compute kernel to use for embedding tables
-        batch_sizes: Sizes of each batch
-        pooling_factors: Pooling factors for each feature of the table
-        num_poolings: Number of poolings for each feature of the table
-
-    Returns:
-        An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner
-
-    Raises:
-        RuntimeError: If an unknown planner type is specified
-    """
-    # Create parameter constraints for tables
-    constraints = {}
-    num_batches = len(batch_sizes)
-
-    if pooling_factors is None:
-        pooling_factors = [POOLING_FACTOR] * num_batches
-
-    if num_poolings is None:
-        num_poolings = [NUM_POOLINGS] * num_batches
-
-    assert (
-        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
-    ), "The length of pooling_factors and num_poolings must match the number of batches."
-
-    if tables is not None:
-        for table in tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-            )
-
-    if weighted_tables is not None:
-        for table in weighted_tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-                is_weighted=True,
-            )
-
-    if planner_type == "embedding":
-        return EmbeddingShardingPlanner(
-            topology=topology,
-            constraints=constraints if constraints else None,
-        )
-    elif planner_type == "hetero":
-        topology_groups = {"cuda": topology}
-        return HeteroEmbeddingShardingPlanner(
-            topology_groups=topology_groups,
-            constraints=constraints if constraints else None,
-        )
-    else:
-        raise RuntimeError(f"Unknown planner type: {planner_type}")
-
-
-def generate_sharded_model_and_optimizer(
-    model: nn.Module,
-    sharding_type: str,
-    kernel_type: str,
-    pg: dist.ProcessGroup,
-    device: torch.device,
-    fused_params: Dict[str, Any],
-    dense_optimizer: str,
-    dense_lr: float,
-    dense_momentum: Optional[float],
-    dense_weight_decay: Optional[float],
-    planner: Optional[
-        Union[
-            EmbeddingShardingPlanner,
-            HeteroEmbeddingShardingPlanner,
-        ]
-    ] = None,
-) -> Tuple[nn.Module, Optimizer]:
-
-    sharder = TestEBCSharder(
-        sharding_type=sharding_type,
-        kernel_type=kernel_type,
-        fused_params=fused_params,
-    )
-    sharders = [cast(ModuleSharder[nn.Module], sharder)]
-
-    # Use planner if provided
-    plan = None
-    if planner is not None:
-        if pg is not None:
-            plan = planner.collective_plan(model, sharders, pg)
-        else:
-            plan = planner.plan(model, sharders)
-
-    sharded_model = DistributedModelParallel(
-        module=copy.deepcopy(model),
-        env=ShardingEnv.from_process_group(pg),
-        init_data_parallel=True,
-        device=device,
-        sharders=sharders,
-        plan=plan,
-    ).to(device)
-
-    # Get dense parameters
-    dense_params = [
-        param
-        for name, param in sharded_model.named_parameters()
-        if "sparse" not in name
-    ]
-
-    # Create optimizer based on the specified type
-    optimizer_class = getattr(optim, dense_optimizer)
-
-    # Create optimizer with momentum and/or weight_decay if provided
-    optimizer_kwargs = {"lr": dense_lr}
-
-    if dense_momentum is not None:
-        optimizer_kwargs["momentum"] = dense_momentum
-
-    if dense_weight_decay is not None:
-        optimizer_kwargs["weight_decay"] = dense_weight_decay
-
-    optimizer = optimizer_class(dense_params, **optimizer_kwargs)
-
-    return sharded_model, optimizer
-
-
 def generate_data(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
```
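
Taken together, the two helpers removed above compose roughly as in the sketch below. This is based only on the signatures shown in this diff; the distributed setup (an initialized process group, the topology values, the enum choices, and the `unsharded_model` placeholder) is an assumption for illustration, not part of this commit.

```python
# Sketch based on the signatures of the helpers removed above. Assumes
# torch.distributed.init_process_group(...) has already been called and that
# `tables`, `weighted_tables` (from generate_tables) and `unsharded_model`
# (e.g. a TestSparseNN built over the same tables) already exist.
import torch
import torch.distributed as dist
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import Topology
from torchrec.distributed.types import ShardingType

device = torch.device("cuda")
topology = Topology(world_size=2, compute_device="cuda")  # assumed world size

planner = generate_planner(
    planner_type="embedding",       # or "hetero"
    topology=topology,
    tables=tables,
    weighted_tables=weighted_tables,
    sharding_type=ShardingType.TABLE_WISE,
    compute_kernel=EmbeddingComputeKernel.FUSED,
    batch_sizes=[8192],
    pooling_factors=None,           # defaults to POOLING_FACTOR per batch
    num_poolings=None,              # defaults to NUM_POOLINGS per batch
)

sharded_model, optimizer = generate_sharded_model_and_optimizer(
    model=unsharded_model,
    sharding_type=ShardingType.TABLE_WISE.value,
    kernel_type=EmbeddingComputeKernel.FUSED.value,
    pg=dist.group.WORLD,
    device=device,
    fused_params={},
    dense_optimizer="SGD",
    dense_lr=0.1,
    dense_momentum=None,
    dense_weight_decay=None,
    planner=planner,
)
```

After this change the helpers are imported from the shared benchmark utilities module instead of benchmark_pipeline_utils.py, as shown by the import hunks in benchmark_train_pipeline.py below.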

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -33,9 +33,6 @@
     DLRMConfig,
     generate_data,
     generate_pipeline,
-    generate_planner,
-    generate_sharded_model_and_optimizer,
-    generate_tables,
     TestSparseNNConfig,
     TestTowerCollectionSparseNNConfig,
     TestTowerSparseNNConfig,
@@ -44,6 +41,9 @@
     benchmark_func,
     BenchmarkResult,
     cmd_conf,
+    generate_planner,
+    generate_sharded_model_and_optimizer,
+    generate_tables,
 )
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
```
