We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d501c48 commit 5d188e8Copy full SHA for 5d188e8
torchtnt/utils/device_mesh.py
@@ -79,6 +79,19 @@ def tp_mesh(self) -> Optional[DeviceMesh]:
79
return None
80
81
82
+def get_dp_mesh(global_mesh: GlobalMeshCoordinator) -> DeviceMesh:
83
+ """
84
+ Retrieves the data parallel mesh from the global mesh coordinator.
85
+
86
+ Args:
87
+ global_mesh (GlobalMeshCoordinator): The global mesh coordinator instance.
88
89
+ Returns:
90
+ DeviceMesh: The data parallel mesh.
91
92
+ return global_mesh.dp_mesh
93
94
95
def get_dp_mesh_size(global_mesh: GlobalMeshCoordinator) -> int:
96
"""
97
Retrieves the size of the data parallel mesh from the global mesh coordinator.
0 commit comments