Support buffer size argument in broadcast_str util (#988)
Summary:
Pull Request resolved: #988
This commit adds a new test case to the DistributedTest class to test the
broadcast_str function with a fixed buffer size. The test case checks that
the broadcasted value is correct when the fixed buffer size is larger than,
equal to, and smaller than the length of the input string.
Note: This commit also includes some minor changes to the existing test cases
to make them more robust.
Changes:
- Added new test case test_broadcast_str_fixed_buffer_size to DistributedTest
- Updated existing test cases to use spawn_multi_process instead of spawn
Error:
The test case is currently failing due to an "enforce fail" error in the Gloo
backend. Further investigation is needed to determine the root cause of this
error.
Reviewed By: JKSenthil
Differential Revision: D72077879
fbshipit-source-id: cc334a6fc04371d3cc7da9f583d4fec3221dad59
torchtnt/utils/distributed.py
37 additions & 12 deletions
@@ -688,6 +688,7 @@ def broadcast_str(
     val: Optional[str],
     src: int = 0,
     process_group: Optional[dist.ProcessGroup] = None,
+    fixed_buffer_size: Optional[int] = None,
 ) -> Optional[str]:
     """
     Broadcasts a string from a source rank to all other ranks in a process group.
@@ -698,18 +699,23 @@ def broadcast_str(
         val: the string to broadcast
         src: the source rank to broadcast from
         process_group: the process group to broadcast in. Defaults to the WORLD process group.
+        fixed_buffer_size (int, optional): The fixed buffer size to use. Defaults to none.
+            If provided, it reduces the number of collective calls by padding the string to a fixed length.
 
     Returns:
         The broadcasted string.
 
     Note:
         This function issues two collective calls, one to broadcast the size of the serialized string and
-        one to broadcast the string itself. This can theoretically be limited to one collective call
-        by hardcoding maximum buffer size to use, and filling unused buffer slots with preselected
-        null tokens. However, this is not implemented to avoid unnecessary complexity.
+        one to broadcast the string itself. If you want to avoid two collective calls, you can pass a fixed_buffer_size
+        parameter. This will cause the string to be padded to the fixed length and only one broadcast will be performed.
+        However, this comes with the cost of extra memory usage.
+        If the string length is less than the buffer size, src rank will terminate early. However, receiving ranks may see collective hang, as expecting data from src rank. Please ensure the buffer size is large enough to avoid this issue.
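
The padding idea described in the docstring can be illustrated without any distributed setup. The following is a minimal sketch, not the torchtnt implementation: the helper names `pad_to_fixed_size` and `unpad`, the NUL padding byte, and the choice to raise `ValueError` when the string does not fit are all assumptions made for illustration. It only shows how a string could be packed into a buffer of known length, so that every rank can allocate the receive buffer up front and a single broadcast suffices.

```python
# Hypothetical padding byte; this sketch assumes the payload never contains NUL.
PAD_BYTE = b"\x00"


def pad_to_fixed_size(val: str, fixed_buffer_size: int) -> bytes:
    """Encode `val` as UTF-8 and right-pad it to exactly `fixed_buffer_size` bytes.

    Raising on overflow is a choice made for this sketch, not documented
    torchtnt behavior.
    """
    data = val.encode("utf-8")
    if len(data) > fixed_buffer_size:
        raise ValueError(
            f"string needs {len(data)} bytes, buffer holds {fixed_buffer_size}"
        )
    return data + PAD_BYTE * (fixed_buffer_size - len(data))


def unpad(buf: bytes) -> str:
    """Strip trailing padding bytes and decode the remaining payload."""
    return buf.rstrip(PAD_BYTE).decode("utf-8")


if __name__ == "__main__":
    buf = pad_to_fixed_size("foo", 8)
    print(len(buf), unpad(buf))
```

Because the buffer length is fixed in advance, receivers never need the separate size broadcast. The trade-off is the one the docstring names: every call transfers `fixed_buffer_size` bytes regardless of how short the string is.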