@@ -166,6 +166,14 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
166166 ),
167167 )
168168
def __repr__(self) -> str:
    """Return a constructor-style summary of this style's configuration.

    The result has the form
    ``ClassName(input_layouts=..., output_layouts=..., use_local_output=...)``.
    The class name is read from the runtime type, so a subclass reports
    its own name.

    Returns:
        The formatted, single-line description string.
    """
    fields = (
        f"input_layouts={self.input_layouts}",
        f"output_layouts={self.output_layouts}",
        f"use_local_output={self.use_local_output}",
    )
    # Joining with ", " guarantees no stray space ends up before a comma
    # or before the closing parenthesis.
    return f"{type(self).__name__}({', '.join(fields)})"
176+
169177
170178class RowwiseParallel (ParallelStyle ):
171179 """
@@ -303,6 +311,14 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
303311 ),
304312 )
305313
def __repr__(self) -> str:
    """Return a constructor-style summary of this style's configuration.

    The result has the form
    ``ClassName(input_layouts=..., output_layouts=..., use_local_output=...)``.
    The class name is read from the runtime type, so a subclass reports
    its own name.

    Returns:
        The formatted, single-line description string.
    """
    fields = (
        f"input_layouts={self.input_layouts}",
        f"output_layouts={self.output_layouts}",
        f"use_local_output={self.use_local_output}",
    )
    # Joining with ", " guarantees no stray space ends up before a comma
    # or before the closing parenthesis.
    return f"{type(self).__name__}({', '.join(fields)})"
321+
306322
307323class SequenceParallel (ParallelStyle ):
308324 """
@@ -398,6 +414,14 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
398414 partial (self ._prepare_output_fn , self .use_local_output ),
399415 )
400416
def __repr__(self) -> str:
    """Return a constructor-style summary of this style's configuration.

    ``sequence_dim`` is reported only when ``self.sequence_sharding`` holds
    exactly one entry; in that case the entry's ``dim`` attribute is shown.
    ``use_local_output`` is always reported.

    Returns:
        The formatted, single-line description string, e.g.
        ``ClassName(sequence_dim=1, use_local_output=False)``.
    """
    fields = []
    # Only a single sharding entry maps to one unambiguous dimension.
    if len(self.sequence_sharding) == 1:
        fields.append(f"sequence_dim={self.sequence_sharding[0].dim}")
    fields.append(f"use_local_output={self.use_local_output}")
    # Joining with ", " guarantees no stray space ends up before a comma
    # or before the closing parenthesis.
    return f"{type(self).__name__}({', '.join(fields)})"
424+
401425
402426class PrepareModuleInput (ParallelStyle ):
403427 """
@@ -557,6 +581,16 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
557581 ) # type: ignore[misc, call-arg]
558582 return module
559583
def __repr__(self) -> str:
    """Return a constructor-style summary of this style's configuration.

    Reports, in order: ``input_layouts``, ``desired_input_layouts``,
    ``input_kwarg_layouts``, ``desired_input_kwarg_layouts`` and
    ``use_local_output``. The class name is read from the runtime type,
    so a subclass reports its own name.

    Returns:
        The formatted, single-line description string.
    """
    fields = (
        f"input_layouts={self.input_layouts}",
        f"desired_input_layouts={self.desired_input_layouts}",
        f"input_kwarg_layouts={self.input_kwarg_layouts}",
        f"desired_input_kwarg_layouts={self.desired_input_kwarg_layouts}",
        f"use_local_output={self.use_local_output}",
    )
    # Joining with ", " guarantees no stray space ends up before a comma
    # or before the closing parenthesis.
    return f"{type(self).__name__}({', '.join(fields)})"
593+
560594
561595class PrepareModuleOutput (ParallelStyle ):
562596 """
@@ -656,3 +690,11 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
656690 lambda _ , inputs , outputs : self ._prepare_out_fn (outputs , device_mesh )
657691 ) # type: ignore[misc, call-arg]
658692 return module
693+
def __repr__(self) -> str:
    """Return a constructor-style summary of this style's configuration.

    Reports, in order: ``output_layouts``, ``desired_output_layouts`` and
    ``use_local_output``. The class name is read from the runtime type,
    so a subclass reports its own name.

    Returns:
        The formatted, single-line description string.
    """
    fields = (
        f"output_layouts={self.output_layouts}",
        f"desired_output_layouts={self.desired_output_layouts}",
        f"use_local_output={self.use_local_output}",
    )
    # Joining with ", " guarantees no stray space ends up before a comma
    # or before the closing parenthesis.
    return f"{type(self).__name__}({', '.join(fields)})"
0 commit comments