 16 |  16 |  3. Add model-specific params to ModelSelectionConfig and create_model_config's arguments in benchmark_train_pipeline.py
 17 |  17 | """
 18 |  18 |
 19 |     | -import copy
 20 |  19 |  from abc import ABC, abstractmethod
 21 |  20 |  from dataclasses import dataclass, fields
 22 |     | -from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
    |  21 | +from typing import Dict, List, Optional, Type, Union
 23 |  22 |
 24 |  23 |  import torch
 25 |     | -import torch.distributed as dist
 26 |     | -from torch import nn, optim
 27 |     | -from torch.optim import Optimizer
 28 |     | -from torchrec.distributed import DistributedModelParallel
 29 |     | -from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 30 |     | -from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 31 |     | -from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
 32 |     | -from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 33 |     | -from torchrec.distributed.planner.types import ParameterConstraints
    |  24 | +from torch import nn
 34 |  25 |  from torchrec.distributed.test_utils.test_input import ModelInput
 35 |  26 |  from torchrec.distributed.test_utils.test_model import (
 36 |     | -    TestEBCSharder,
 37 |  27 |      TestSparseNN,
 38 |  28 |      TestTowerCollectionSparseNN,
 39 |  29 |      TestTowerSparseNN,
  ⋯ |   ⋯ |  (unchanged lines collapsed)
 47 |  37 |      PrefetchTrainPipelineSparseDist,
 48 |  38 |      TrainPipelineSemiSync,
 49 |  39 |  )
 50 |     | -from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
 51 |  40 |  from torchrec.models.deepfm import SimpleDeepFMNNWrapper
 52 |  41 |  from torchrec.models.dlrm import DLRMWrapper
 53 |  42 |  from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -248,55 +237,6 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
248 | 237 |      return model_class(**filtered_kwargs)
249 | 238 |
250 | 239 |
251 |     | -def generate_tables(
252 |     | -    num_unweighted_features: int,
253 |     | -    num_weighted_features: int,
254 |     | -    embedding_feature_dim: int,
255 |     | -) -> Tuple[
256 |     | -    List[EmbeddingBagConfig],
257 |     | -    List[EmbeddingBagConfig],
258 |     | -]:
259 |     | -    """
260 |     | -    Generate embedding bag configurations for both unweighted and weighted features.
261 |     | -
262 |     | -    This function creates two lists of EmbeddingBagConfig objects:
263 |     | -    1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}"
264 |     | -    2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}"
265 |     | -
266 |     | -    For both types, the number of embeddings scales with the feature index,
267 |     | -    calculated as max(i + 1, 100) * 1000.
268 |     | -
269 |     | -    Args:
270 |     | -        num_unweighted_features (int): Number of unweighted features to generate.
271 |     | -        num_weighted_features (int): Number of weighted features to generate.
272 |     | -        embedding_feature_dim (int): Dimension of the embedding vectors.
273 |     | -
274 |     | -    Returns:
275 |     | -        Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing
276 |     | -        two lists - the first for unweighted embedding tables and the second for
277 |     | -        weighted embedding tables.
278 |     | -    """
279 |     | -    tables = [
280 |     | -        EmbeddingBagConfig(
281 |     | -            num_embeddings=max(i + 1, 100) * 1000,
282 |     | -            embedding_dim=embedding_feature_dim,
283 |     | -            name="table_" + str(i),
284 |     | -            feature_names=["feature_" + str(i)],
285 |     | -        )
286 |     | -        for i in range(num_unweighted_features)
287 |     | -    ]
288 |     | -    weighted_tables = [
289 |     | -        EmbeddingBagConfig(
290 |     | -            num_embeddings=max(i + 1, 100) * 1000,
291 |     | -            embedding_dim=embedding_feature_dim,
292 |     | -            name="weighted_table_" + str(i),
293 |     | -            feature_names=["weighted_feature_" + str(i)],
294 |     | -        )
295 |     | -        for i in range(num_weighted_features)
296 |     | -    ]
297 |     | -    return tables, weighted_tables
298 |     | -
299 |     | -
300 | 240 |  def generate_pipeline(
301 | 241 |      pipeline_type: str,
302 | 242 |      emb_lookup_stream: str,
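For reference, the removed `generate_tables` helper only builds `EmbeddingBagConfig` lists, so it can be exercised without any distributed setup. A minimal sketch of how it behaved; the feature counts and `embedding_feature_dim` below are arbitrary example values, not the benchmark defaults:

```python
# Sketch only: argument values are illustrative.
tables, weighted_tables = generate_tables(
    num_unweighted_features=2,
    num_weighted_features=1,
    embedding_feature_dim=128,
)

# Per the docstring above: names follow "table_{i}" / "weighted_table_{i}" and
# num_embeddings = max(i + 1, 100) * 1000, i.e. 100_000 for the first 100 tables.
assert tables[0].name == "table_0" and tables[0].num_embeddings == 100_000
assert weighted_tables[0].feature_names == ["weighted_feature_0"]
```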
@@ -371,156 +311,6 @@ def generate_pipeline(
371 | 311 |      return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit)
372 | 312 |
373 | 313 |
374 |     | -def generate_planner(
375 |     | -    planner_type: str,
376 |     | -    topology: Topology,
377 |     | -    tables: Optional[List[EmbeddingBagConfig]],
378 |     | -    weighted_tables: Optional[List[EmbeddingBagConfig]],
379 |     | -    sharding_type: ShardingType,
380 |     | -    compute_kernel: EmbeddingComputeKernel,
381 |     | -    batch_sizes: List[int],
382 |     | -    pooling_factors: Optional[List[float]],
383 |     | -    num_poolings: Optional[List[float]],
384 |     | -) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
385 |     | -    """
386 |     | -    Generate an embedding sharding planner based on the specified configuration.
387 |     | -
388 |     | -    Args:
389 |     | -        planner_type: Type of planner to use ("embedding" or "hetero")
390 |     | -        topology: Network topology for distributed training
391 |     | -        tables: List of unweighted embedding tables
392 |     | -        weighted_tables: List of weighted embedding tables
393 |     | -        sharding_type: Strategy for sharding embedding tables
394 |     | -        compute_kernel: Compute kernel to use for embedding tables
395 |     | -        batch_sizes: Sizes of each batch
396 |     | -        pooling_factors: Pooling factors for each feature of the table
397 |     | -        num_poolings: Number of poolings for each feature of the table
398 |     | -
399 |     | -    Returns:
400 |     | -        An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner
401 |     | -
402 |     | -    Raises:
403 |     | -        RuntimeError: If an unknown planner type is specified
404 |     | -    """
405 |     | -    # Create parameter constraints for tables
406 |     | -    constraints = {}
407 |     | -    num_batches = len(batch_sizes)
408 |     | -
409 |     | -    if pooling_factors is None:
410 |     | -        pooling_factors = [POOLING_FACTOR] * num_batches
411 |     | -
412 |     | -    if num_poolings is None:
413 |     | -        num_poolings = [NUM_POOLINGS] * num_batches
414 |     | -
415 |     | -    assert (
416 |     | -        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
417 |     | -    ), "The length of pooling_factors and num_poolings must match the number of batches."
418 |     | -
419 |     | -    if tables is not None:
420 |     | -        for table in tables:
421 |     | -            constraints[table.name] = ParameterConstraints(
422 |     | -                sharding_types=[sharding_type.value],
423 |     | -                compute_kernels=[compute_kernel.value],
424 |     | -                device_group="cuda",
425 |     | -                pooling_factors=pooling_factors,
426 |     | -                num_poolings=num_poolings,
427 |     | -                batch_sizes=batch_sizes,
428 |     | -            )
429 |     | -
430 |     | -    if weighted_tables is not None:
431 |     | -        for table in weighted_tables:
432 |     | -            constraints[table.name] = ParameterConstraints(
433 |     | -                sharding_types=[sharding_type.value],
434 |     | -                compute_kernels=[compute_kernel.value],
435 |     | -                device_group="cuda",
436 |     | -                pooling_factors=pooling_factors,
437 |     | -                num_poolings=num_poolings,
438 |     | -                batch_sizes=batch_sizes,
439 |     | -                is_weighted=True,
440 |     | -            )
441 |     | -
442 |     | -    if planner_type == "embedding":
443 |     | -        return EmbeddingShardingPlanner(
444 |     | -            topology=topology,
445 |     | -            constraints=constraints if constraints else None,
446 |     | -        )
447 |     | -    elif planner_type == "hetero":
448 |     | -        topology_groups = {"cuda": topology}
449 |     | -        return HeteroEmbeddingShardingPlanner(
450 |     | -            topology_groups=topology_groups,
451 |     | -            constraints=constraints if constraints else None,
452 |     | -        )
453 |     | -    else:
454 |     | -        raise RuntimeError(f"Unknown planner type: {planner_type}")
455 |     | -
456 |     | -
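Likewise, the removed `generate_planner` could be driven as below. This is a sketch only: the topology, sharding type, compute kernel, and batch size are assumed values, and `tables`/`weighted_tables` are the configs from the `generate_tables` sketch above.

```python
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import Topology
from torchrec.distributed.types import ShardingType

# Assumed single-host topology; any world size / device combination works the same way.
topology = Topology(world_size=2, compute_device="cuda")

planner = generate_planner(
    planner_type="embedding",        # "hetero" would return a HeteroEmbeddingShardingPlanner
    topology=topology,
    tables=tables,
    weighted_tables=weighted_tables,
    sharding_type=ShardingType.TABLE_WISE,
    compute_kernel=EmbeddingComputeKernel.FUSED,
    batch_sizes=[8192],              # one entry per batch
    pooling_factors=None,            # falls back to POOLING_FACTOR per batch
    num_poolings=None,               # falls back to NUM_POOLINGS per batch
)
```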
457 |     | -def generate_sharded_model_and_optimizer(
458 |     | -    model: nn.Module,
459 |     | -    sharding_type: str,
460 |     | -    kernel_type: str,
461 |     | -    pg: dist.ProcessGroup,
462 |     | -    device: torch.device,
463 |     | -    fused_params: Dict[str, Any],
464 |     | -    dense_optimizer: str,
465 |     | -    dense_lr: float,
466 |     | -    dense_momentum: Optional[float],
467 |     | -    dense_weight_decay: Optional[float],
468 |     | -    planner: Optional[
469 |     | -        Union[
470 |     | -            EmbeddingShardingPlanner,
471 |     | -            HeteroEmbeddingShardingPlanner,
472 |     | -        ]
473 |     | -    ] = None,
474 |     | -) -> Tuple[nn.Module, Optimizer]:
475 |     | -
476 |     | -    sharder = TestEBCSharder(
477 |     | -        sharding_type=sharding_type,
478 |     | -        kernel_type=kernel_type,
479 |     | -        fused_params=fused_params,
480 |     | -    )
481 |     | -    sharders = [cast(ModuleSharder[nn.Module], sharder)]
482 |     | -
483 |     | -    # Use planner if provided
484 |     | -    plan = None
485 |     | -    if planner is not None:
486 |     | -        if pg is not None:
487 |     | -            plan = planner.collective_plan(model, sharders, pg)
488 |     | -        else:
489 |     | -            plan = planner.plan(model, sharders)
490 |     | -
491 |     | -    sharded_model = DistributedModelParallel(
492 |     | -        module=copy.deepcopy(model),
493 |     | -        env=ShardingEnv.from_process_group(pg),
494 |     | -        init_data_parallel=True,
495 |     | -        device=device,
496 |     | -        sharders=sharders,
497 |     | -        plan=plan,
498 |     | -    ).to(device)
499 |     | -
500 |     | -    # Get dense parameters
501 |     | -    dense_params = [
502 |     | -        param
503 |     | -        for name, param in sharded_model.named_parameters()
504 |     | -        if "sparse" not in name
505 |     | -    ]
506 |     | -
507 |     | -    # Create optimizer based on the specified type
508 |     | -    optimizer_class = getattr(optim, dense_optimizer)
509 |     | -
510 |     | -    # Create optimizer with momentum and/or weight_decay if provided
511 |     | -    optimizer_kwargs = {"lr": dense_lr}
512 |     | -
513 |     | -    if dense_momentum is not None:
514 |     | -        optimizer_kwargs["momentum"] = dense_momentum
515 |     | -
516 |     | -    if dense_weight_decay is not None:
517 |     | -        optimizer_kwargs["weight_decay"] = dense_weight_decay
518 |     | -
519 |     | -    optimizer = optimizer_class(dense_params, **optimizer_kwargs)
520 |     | -
521 |     | -    return sharded_model, optimizer
522 |     | -
523 |     | -
524 | 314 |  def generate_data(
525 | 315 |      tables: List[EmbeddingBagConfig],
526 | 316 |      weighted_tables: List[EmbeddingBagConfig],