Skip to content

Commit bf6621d

Browse files
shinkpytorchmergebot
authored and committed
[Distributed] Add repr methods for ParallelStyles (pytorch#149478)
Fixes pytorch#149470 Pull Request resolved: pytorch#149478 Approved by: https://github.com/wanchaol
1 parent ee6a029 commit bf6621d

File tree

1 file changed

+42
-0
lines changed
  • torch/distributed/tensor/parallel

1 file changed

+42
-0
lines changed

torch/distributed/tensor/parallel/style.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,14 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
166166
),
167167
)
168168

169+
def __repr__(self) -> str:
170+
tmpstr = self.__class__.__name__ + "("
171+
tmpstr += f"input_layouts={self.input_layouts}, "
172+
tmpstr += f"output_layouts={self.output_layouts}, "
173+
tmpstr += f"use_local_output={self.use_local_output}"
174+
tmpstr += ")"
175+
return tmpstr
176+
169177

170178
class RowwiseParallel(ParallelStyle):
171179
"""
@@ -303,6 +311,14 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
303311
),
304312
)
305313

314+
def __repr__(self) -> str:
315+
tmpstr = self.__class__.__name__ + "("
316+
tmpstr += f"input_layouts={self.input_layouts}, "
317+
tmpstr += f"output_layouts={self.output_layouts}, "
318+
tmpstr += f"use_local_output={self.use_local_output}"
319+
tmpstr += ")"
320+
return tmpstr
321+
306322

307323
class SequenceParallel(ParallelStyle):
308324
"""
@@ -398,6 +414,14 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
398414
partial(self._prepare_output_fn, self.use_local_output),
399415
)
400416

417+
def __repr__(self) -> str:
418+
tmpstr = self.__class__.__name__ + "("
419+
if len(self.sequence_sharding) == 1:
420+
tmpstr += f"sequence_dim={self.sequence_sharding[0].dim}, "
421+
tmpstr += f"use_local_output={self.use_local_output}"
422+
tmpstr += ")"
423+
return tmpstr
424+
401425

402426
class PrepareModuleInput(ParallelStyle):
403427
"""
@@ -557,6 +581,16 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
557581
) # type: ignore[misc, call-arg]
558582
return module
559583

584+
def __repr__(self) -> str:
585+
tmpstr = self.__class__.__name__ + "("
586+
tmpstr += f"input_layouts={self.input_layouts}, "
587+
tmpstr += f"desired_input_layouts={self.desired_input_layouts}, "
588+
tmpstr += f"input_kwarg_layouts={self.input_kwarg_layouts}, "
589+
tmpstr += f"desired_input_kwarg_layouts={self.desired_input_kwarg_layouts}, "
590+
tmpstr += f"use_local_output={self.use_local_output}"
591+
tmpstr += ")"
592+
return tmpstr
593+
560594

561595
class PrepareModuleOutput(ParallelStyle):
562596
"""
@@ -656,3 +690,11 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
656690
lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)
657691
) # type: ignore[misc, call-arg]
658692
return module
693+
694+
def __repr__(self) -> str:
695+
tmpstr = self.__class__.__name__ + "("
696+
tmpstr += f"output_layouts={self.output_layouts}, "
697+
tmpstr += f"desired_output_layouts={self.desired_output_layouts}, "
698+
tmpstr += f"use_local_output={self.use_local_output}"
699+
tmpstr += ")"
700+
return tmpstr

0 commit comments

Comments
 (0)