
Commit 50caae8

Merge branch 'pytorch:main' into temp-ppc64le-wheel-branch-v8
2 parents: c979154 + 1700599

File tree

8 files changed: +350 -84 lines changed


test/distributed/test_symmetric_memory.py

Lines changed: 52 additions & 37 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["module: c10d"]

+import itertools
 import os
 from unittest import skipIf

@@ -860,55 +861,69 @@ def test_multimem_one_shot_all_reduce(

     @skipIfRocm
     @skip_if_lt_x_gpu(4)
-    @parametrize("dtype", [torch.float, torch.bfloat16])
-    @parametrize("align_bytes", [4, 8, 16])
-    @parametrize("size_bytes", [4, 8192, 8196])
-    def test_one_shot_all_reduce(
-        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
-    ) -> None:
+    def test_one_shot_all_reduce(self) -> None:
         self._init_process()
         group_name = dist.group.WORLD.group_name

-        inp = symm_mem.empty(
-            size_bytes // dtype.itemsize, dtype=dtype, device=self.device
-        ).normal_()
-        symm_mem.rendezvous(inp, group=group_name)
-
-        res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name)
-        self._verify_all_reduce_result(inp, res)
+        for dtype, size_bytes, align_bytes, copy, offset in itertools.product(
+            [torch.float, torch.bfloat16],
+            [4, 8192, 8196],
+            [4, 8, 16],
+            [True, False],
+            [0, 16],
+        ):
+            inp = symm_mem.empty(
+                size_bytes // dtype.itemsize + offset, dtype=dtype, device=self.device
+            )
+            symm_mem.rendezvous(inp, group=group_name)
+            if not copy:
+                inp.normal_()
+                res = torch.ops.symm_mem.one_shot_all_reduce(
+                    inp[offset:], "sum", group_name
+                )
+            if copy:
+                local_inp = torch.randn_like(inp[offset:])
+                res = torch.ops.symm_mem.one_shot_all_reduce_copy(
+                    inp[offset:], local_inp, "sum", group_name
+                )
+            self._verify_all_reduce_result(local_inp if copy else inp[offset:], res)

         dist.destroy_process_group()

     @skipIfRocm
     @skip_if_lt_x_gpu(4)
-    @parametrize("dtype", [torch.float, torch.bfloat16])
-    @parametrize("align_bytes", [4, 8, 16])
-    @parametrize("size_bytes", [4, 8192, 8196])
-    def test_two_shot_all_reduce(
-        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
-    ) -> None:
+    def test_two_shot_all_reduce(self) -> None:
         self._init_process()
         group_name = dist.group.WORLD.group_name

-        t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
-        symm_mem.rendezvous(t, group=group_name)
-
-        self.assertTrue(t.data_ptr() % 16 == 0)
-        self.assertTrue(align_bytes % t.element_size() == 0)
-        self.assertTrue(size_bytes % t.element_size() == 0)
-
-        shift = align_bytes // t.element_size()
-        numel = size_bytes // t.element_size()
-        res = t[shift : shift + numel]
-        res.normal_()
-        inp = res.clone()
-
-        torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name)
+        for dtype, size_bytes, align_bytes, inplace in itertools.product(
+            [torch.float, torch.bfloat16],
+            [4, 8192, 8196],
+            [4, 8, 16],
+            [True, False],
+        ):
+            t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
+            symm_mem.rendezvous(t, group=group_name)
+
+            self.assertTrue(t.data_ptr() % 16 == 0)
+            self.assertTrue(align_bytes % t.element_size() == 0)
+            self.assertTrue(size_bytes % t.element_size() == 0)
+
+            shift = align_bytes // t.element_size()
+            numel = size_bytes // t.element_size()
+            res = t[shift : shift + numel]
+            res.normal_().fill_(1)
+            inp = res.clone()
+            if not inplace:
+                out = torch.empty_like(inp)
+                torch.ops.symm_mem.two_shot_all_reduce_out(res, "sum", group_name, out)
+            else:
+                torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name)

-        # Head and tail should not be written
-        self.assertTrue(t[:shift].eq(0).all().item())
-        self.assertTrue(t[shift + numel :].eq(0).all().item())
-        self._verify_all_reduce_result(inp, res)
+            # Head and tail should not be written
+            self.assertTrue(t[:shift].eq(0).all().item())
+            self.assertTrue(t[shift + numel :].eq(0).all().item())
+            self._verify_all_reduce_result(inp, res if inplace else out)

         dist.destroy_process_group()

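The reworked tests above fold the old @parametrize axes into an itertools.product loop and add coverage for the copy and out-of-place all-reduce variants. A minimal usage sketch of the out-of-place two-shot variant exercised here (illustrative only; it assumes an already initialized process group, a CUDA device, and that torch.distributed._symmetric_memory is importable as symm_mem — buffer size and dtype are arbitrary):

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    # Assumes dist.init_process_group(...) has already run on every rank.
    group_name = dist.group.WORLD.group_name

    # Symmetric-memory buffer; must be rendezvoused across ranks before use.
    buf = symm_mem.empty(2048, dtype=torch.bfloat16, device="cuda").fill_(1)
    symm_mem.rendezvous(buf, group=group_name)

    # Out-of-place variant: the reduced values land in `out`,
    # while `buf` stays the communication buffer.
    out = torch.empty_like(buf)
    torch.ops.symm_mem.two_shot_all_reduce_out(buf, "sum", group_name, out)
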
test/dynamo/test_error_messages.py

Lines changed: 32 additions & 7 deletions
@@ -11,6 +11,7 @@
 import torch._dynamo.test_case
 import torch.utils._pytree as python_pytree
 from torch._dynamo.exc import Unsupported
+from torch._dynamo.testing import skipIfNotPy312
 from torch._dynamo.utils import counters
 from torch.testing._internal.common_utils import (
     IS_FBCODE,
@@ -646,18 +647,42 @@ def fn():
 """,
         )

-    def test_unsupported_bytecode(self):
+    def test_load_build_class(self):
         def fn():
             class Foo:
                 pass

             return Foo

+        self.assertExpectedInlineMunged(
+            Unsupported,
+            lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
+            """\
+LOAD_BUILD_CLASS bytecode not supported
+Explanation: Dynamo does not support tracing classes that are defined in the compiled region.
+Hint: Move the class definition out of the compiled region.
+Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
+
+Developer debug context:
+
+
+from user code:
+  File "test_error_messages.py", line N, in fn
+    class Foo:""",
+        )
+
+    @skipIfNotPy312
+    def test_unsupported_bytecode(self):
+        async def fn():
+            async for i in range(3):
+                print(i)
+            return 1
+
         def post_munge(s):
             s = re.sub(r"0x[0-9A-Fa-f]+", "0xmem_addr", s)
             s = re.sub(
-                r"Instruction\(.*opname='LOAD_BUILD_CLASS'.*\)\n",
-                "Instruction(LOAD_BUILD_CLASS)",
+                r"Instruction\(.*opname='GET_AITER'.*\)\n",
+                "Instruction(GET_AITER)",
                 s,
             )
             return s
@@ -667,15 +692,15 @@ def post_munge(s):
             lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
             """\
 Missing bytecode handler
-Explanation: Dynamo does not know how to handle the bytecode instruction `LOAD_BUILD_CLASS`.
-Hint: Do not trace code that produces the `LOAD_BUILD_CLASS` bytecode instruction (see https://docs.python.org/3/library/dis.html for bytecode semantics).
+Explanation: Dynamo does not know how to handle the bytecode instruction `GET_AITER`.
+Hint: Do not trace code that produces the `GET_AITER` bytecode instruction (see https://docs.python.org/3/library/dis.html for bytecode semantics).
 Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.

-Developer debug context: LOAD_BUILD_CLASS with args (<torch._dynamo.symbolic_convert.InstructionTranslator object at 0xmem_addr>, Instruction(LOAD_BUILD_CLASS)
+Developer debug context: GET_AITER with args (<torch._dynamo.symbolic_convert.InstructionTranslator object at 0xmem_addr>, Instruction(GET_AITER)

 from user code:
   File "test_error_messages.py", line N, in fn
-    class Foo:""",
+    async for i in range(3):""",
             post_munge=post_munge,
         )

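The old test_unsupported_bytecode case relied on LOAD_BUILD_CLASS, which now has a dedicated handler (see torch/_dynamo/symbolic_convert.py below), so the test is split: test_load_build_class asserts the new graph-break message, and test_unsupported_bytecode switches to GET_AITER (Python 3.12+ only). A minimal repro sketch of the behavior the new test asserts; the exact message text may change between releases:

    import torch

    def fn():
        class Foo:  # defining a class in the compiled region hits LOAD_BUILD_CLASS
            pass
        return Foo

    try:
        torch.compile(fn, backend="eager", fullgraph=True)()
    except Exception as e:
        # Expected to mention "LOAD_BUILD_CLASS bytecode not supported"
        print(type(e).__name__, e)
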
torch/_dynamo/symbolic_convert.py

Lines changed: 11 additions & 0 deletions
@@ -2824,6 +2824,17 @@ def MATCH_KEYS(self, inst):
     def LOAD_ASSERTION_ERROR(self, inst):
         self.load_builtin_from_argval("AssertionError")

+    def LOAD_BUILD_CLASS(self, inst):
+        unimplemented_v2(
+            gb_type="LOAD_BUILD_CLASS bytecode not supported",
+            context="",
+            explanation="Dynamo does not support tracing classes that are defined in the compiled region.",
+            hints=[
+                "Move the class definition out of the compiled region.",
+                *graph_break_hints.SUPPORTABLE,
+            ],
+        )
+
     UNARY_POSITIVE = stack_op(operator.pos)
     UNARY_NEGATIVE = stack_op(operator.neg)
     UNARY_NOT = stack_op(operator.not_)

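The new handler reports LOAD_BUILD_CLASS through unimplemented_v2 with graph_break_hints.SUPPORTABLE, so under fullgraph=True it surfaces as the error asserted in test_load_build_class, while in the default mode it should only cause a graph break. A sketch of that default-mode behavior, under the assumption that the frame simply falls back to eager execution at the break:

    import torch

    def make_class():
        class Foo:
            x = 1
        return Foo

    # Without fullgraph=True, Dynamo graph-breaks at LOAD_BUILD_CLASS and the
    # class body runs outside the compiled graph; the call still succeeds.
    Foo = torch.compile(make_class, backend="eager")()
    print(Foo.x)  # 1
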
torch/_export/passes/lift_constants_pass.py

Lines changed: 9 additions & 4 deletions
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import collections
-import warnings
+import logging
 from typing import Any, Union

 import torch
@@ -19,6 +19,9 @@
 from torch.fx.graph_module import _get_attr


+log = logging.getLogger(__name__)
+
+
 class ConstantAttrMap(collections.abc.MutableMapping):
     """A mapping class that understands how to use module constants (tensors,
     ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally,
@@ -213,9 +216,11 @@ def lift_constants_pass(
         elif isinstance(constant_val, torch.Tensor):
             # Remove the parameterness of constant_val
             if isinstance(constant_val, torch.nn.Parameter):
-                warnings.warn(
-                    f"{node.target} created when tracing {node.meta.get('stack_trace', '<unknown stack>')} is a parameter. But"
-                    f"it's not registered with register_parameter(). export will treat it as a constant tensor"
+                log.debug(
+                    "%s created when tracing %s is a parameter. But "
+                    "it's not registered with register_parameter(). export will treat it as a constant tensor",
+                    str(node.target),
+                    str(node.meta.get("stack_trace", "<unknown stack>")),
                 )
                 # We get the real data out of the parameter by disabling the surrounding fake mode.
                 with unset_fake_temporarily():

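Demoting the warning to log.debug means the message is silent by default. A sketch of how one might surface it again with the standard logging module, assuming the logger name follows the module path (the pass calls logging.getLogger(__name__)):

    import logging

    # Install a handler that prints DEBUG records, then enable them for the pass.
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger("torch._export.passes.lift_constants_pass").setLevel(logging.DEBUG)
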