diff --git a/torchrec/distributed/planner/__init__.py b/torchrec/distributed/planner/__init__.py index efd06bf02..3dd8289e2 100644 --- a/torchrec/distributed/planner/__init__.py +++ b/torchrec/distributed/planner/__init__.py @@ -21,6 +21,9 @@ - automatically building and selecting an optimized sharding plan. """ -from torchrec.distributed.planner.planners import EmbeddingShardingPlanner # noqa +from torchrec.distributed.planner.planners import ( # noqa # noqa + EmbeddingPlannerBase, + EmbeddingShardingPlanner, +) from torchrec.distributed.planner.types import ParameterConstraints, Topology # noqa from torchrec.distributed.planner.utils import bytes_to_gb, sharder_name # noqa diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py index dd79751e8..fba41d4d9 100644 --- a/torchrec/distributed/planner/planners.py +++ b/torchrec/distributed/planner/planners.py @@ -251,6 +251,22 @@ def collective_plan( sharders, ) + def hash_planner_context_inputs(self) -> int: + """ + Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats. + These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context. + + Returns: + Generates a hash capturing topology, batch size, enumerator, storage reservation, stats and constraints. + """ + return hash_planner_context_inputs( + self._topology, + self._batch_size, + self._enumerator, + self._storage_reservation, + self._constraints, + ) + class EmbeddingShardingPlanner(EmbeddingPlannerBase): """ @@ -368,22 +384,6 @@ def collective_plan( sharders, ) - def hash_planner_context_inputs(self) -> int: - """ - Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats. - These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context. - - Returns: - Generates a hash capturing topology, batch size, enumerator, storage reservation, stats and constraints. - """ - return hash_planner_context_inputs( - self._topology, - self._batch_size, - self._enumerator, - self._storage_reservation, - self._constraints, - ) - def plan( self, module: nn.Module,