Skip to content

Commit ce62470

Browse files
Merge branch 'temp-ppc64le-wheel-branch-v8' of https://github.com/sandeepgupta12/pytorch into temp-ppc64le-wheel-branch-v8
2 parents e79bc7e + 6528bbf commit ce62470

File tree

13 files changed

+367
-79
lines changed

13 files changed

+367
-79
lines changed

benchmarks/dynamo/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,7 @@ def load(cls, model, example_inputs):
13971397
# see https://github.com/pytorch/pytorch/issues/113029
13981398
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
13991399

1400-
if pytree._is_namedtuple_instance(example_outputs):
1400+
if pytree.is_namedtuple_instance(example_outputs):
14011401
typ = type(example_outputs)
14021402
pytree._register_namedtuple(
14031403
typ,

test/test_pytree.py

Lines changed: 149 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import re
77
import subprocess
88
import sys
9+
import time
910
import unittest
1011
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
1112
from dataclasses import dataclass
@@ -731,6 +732,133 @@ def test_pytree_serialize_bad_input(self, pytree_impl):
731732
with self.assertRaises(TypeError):
732733
pytree_impl.treespec_dumps("random_blurb")
733734

735+
@parametrize(
736+
"pytree",
737+
[
738+
subtest(py_pytree, name="py"),
739+
subtest(cxx_pytree, name="cxx"),
740+
],
741+
)
742+
def test_is_namedtuple(self, pytree):
743+
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
744+
745+
class DirectNamedTuple2(NamedTuple):
746+
x: int
747+
y: int
748+
749+
class IndirectNamedTuple1(DirectNamedTuple1):
750+
pass
751+
752+
class IndirectNamedTuple2(DirectNamedTuple2):
753+
pass
754+
755+
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1(0, 1)))
756+
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2(0, 1)))
757+
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1(0, 1)))
758+
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2(0, 1)))
759+
self.assertFalse(pytree.is_namedtuple(time.gmtime()))
760+
self.assertFalse(pytree.is_namedtuple((0, 1)))
761+
self.assertFalse(pytree.is_namedtuple([0, 1]))
762+
self.assertFalse(pytree.is_namedtuple({0: 1, 1: 2}))
763+
self.assertFalse(pytree.is_namedtuple({0, 1}))
764+
self.assertFalse(pytree.is_namedtuple(1))
765+
766+
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1))
767+
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2))
768+
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1))
769+
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2))
770+
self.assertFalse(pytree.is_namedtuple(time.struct_time))
771+
self.assertFalse(pytree.is_namedtuple(tuple))
772+
self.assertFalse(pytree.is_namedtuple(list))
773+
774+
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple1))
775+
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple2))
776+
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple1))
777+
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple2))
778+
self.assertFalse(pytree.is_namedtuple_class(time.struct_time))
779+
self.assertFalse(pytree.is_namedtuple_class(tuple))
780+
self.assertFalse(pytree.is_namedtuple_class(list))
781+
782+
@parametrize(
783+
"pytree",
784+
[
785+
subtest(py_pytree, name="py"),
786+
subtest(cxx_pytree, name="cxx"),
787+
],
788+
)
789+
def test_is_structseq(self, pytree):
790+
class FakeStructSeq(tuple):
791+
n_fields = 2
792+
n_sequence_fields = 2
793+
n_unnamed_fields = 0
794+
795+
__slots__ = ()
796+
__match_args__ = ("x", "y")
797+
798+
def __new__(cls, sequence):
799+
return super().__new__(cls, sequence)
800+
801+
@property
802+
def x(self):
803+
return self[0]
804+
805+
@property
806+
def y(self):
807+
return self[1]
808+
809+
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
810+
811+
class DirectNamedTuple2(NamedTuple):
812+
x: int
813+
y: int
814+
815+
self.assertFalse(pytree.is_structseq(FakeStructSeq((0, 1))))
816+
self.assertTrue(pytree.is_structseq(time.gmtime()))
817+
self.assertFalse(pytree.is_structseq(DirectNamedTuple1(0, 1)))
818+
self.assertFalse(pytree.is_structseq(DirectNamedTuple2(0, 1)))
819+
self.assertFalse(pytree.is_structseq((0, 1)))
820+
self.assertFalse(pytree.is_structseq([0, 1]))
821+
self.assertFalse(pytree.is_structseq({0: 1, 1: 2}))
822+
self.assertFalse(pytree.is_structseq({0, 1}))
823+
self.assertFalse(pytree.is_structseq(1))
824+
825+
self.assertFalse(pytree.is_structseq(FakeStructSeq))
826+
self.assertTrue(pytree.is_structseq(time.struct_time))
827+
self.assertFalse(pytree.is_structseq(DirectNamedTuple1))
828+
self.assertFalse(pytree.is_structseq(DirectNamedTuple2))
829+
self.assertFalse(pytree.is_structseq(tuple))
830+
self.assertFalse(pytree.is_structseq(list))
831+
832+
self.assertFalse(pytree.is_structseq_class(FakeStructSeq))
833+
self.assertTrue(
834+
pytree.is_structseq_class(time.struct_time),
835+
)
836+
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple1))
837+
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple2))
838+
self.assertFalse(pytree.is_structseq_class(tuple))
839+
self.assertFalse(pytree.is_structseq_class(list))
840+
841+
# torch.return_types.* are all PyStructSequence types
842+
for cls in vars(torch.return_types).values():
843+
if isinstance(cls, type) and issubclass(cls, tuple):
844+
self.assertTrue(pytree.is_structseq(cls))
845+
self.assertTrue(pytree.is_structseq_class(cls))
846+
self.assertFalse(pytree.is_namedtuple(cls))
847+
self.assertFalse(pytree.is_namedtuple_class(cls))
848+
849+
inst = cls(range(cls.n_sequence_fields))
850+
self.assertTrue(pytree.is_structseq(inst))
851+
self.assertTrue(pytree.is_structseq(type(inst)))
852+
self.assertFalse(pytree.is_structseq_class(inst))
853+
self.assertTrue(pytree.is_structseq_class(type(inst)))
854+
self.assertFalse(pytree.is_namedtuple(inst))
855+
self.assertFalse(pytree.is_namedtuple_class(inst))
856+
else:
857+
self.assertFalse(pytree.is_structseq(cls))
858+
self.assertFalse(pytree.is_structseq_class(cls))
859+
self.assertFalse(pytree.is_namedtuple(cls))
860+
self.assertFalse(pytree.is_namedtuple_class(cls))
861+
734862

735863
class TestPythonPytree(TestCase):
736864
def test_deprecated_register_pytree_node(self):
@@ -975,9 +1103,8 @@ def test_pytree_serialize_namedtuple(self):
9751103
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
9761104
)
9771105

978-
spec = py_pytree.TreeSpec(
979-
namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
980-
)
1106+
spec = py_pytree.tree_structure(Point1(1, 2))
1107+
self.assertIs(spec.type, namedtuple)
9811108
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
9821109
self.assertEqual(spec, roundtrip_spec)
9831110

@@ -990,18 +1117,28 @@ class Point2(NamedTuple):
9901117
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
9911118
)
9921119

993-
spec = py_pytree.TreeSpec(
994-
namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1120+
spec = py_pytree.tree_structure(Point2(1, 2))
1121+
self.assertIs(spec.type, namedtuple)
1122+
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
1123+
self.assertEqual(spec, roundtrip_spec)
1124+
1125+
class Point3(Point2):
1126+
pass
1127+
1128+
py_pytree._register_namedtuple(
1129+
Point3,
1130+
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point3",
9951131
)
1132+
1133+
spec = py_pytree.tree_structure(Point3(1, 2))
1134+
self.assertIs(spec.type, namedtuple)
9961135
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
9971136
self.assertEqual(spec, roundtrip_spec)
9981137

9991138
def test_pytree_serialize_namedtuple_bad(self):
10001139
DummyType = namedtuple("DummyType", ["x", "y"])
10011140

1002-
spec = py_pytree.TreeSpec(
1003-
namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1004-
)
1141+
spec = py_pytree.tree_structure(DummyType(1, 2))
10051142

10061143
with self.assertRaisesRegex(
10071144
NotImplementedError, "Please register using `_register_namedtuple`"
@@ -1020,9 +1157,7 @@ def __init__(self, x, y):
10201157
lambda xs, _: DummyType(*xs),
10211158
)
10221159

1023-
spec = py_pytree.TreeSpec(
1024-
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1025-
)
1160+
spec = py_pytree.tree_structure(DummyType(1, 2))
10261161
with self.assertRaisesRegex(
10271162
NotImplementedError, "No registered serialization name"
10281163
):
@@ -1042,9 +1177,7 @@ def __init__(self, x, y):
10421177
to_dumpable_context=lambda context: "moo",
10431178
from_dumpable_context=lambda dumpable_context: None,
10441179
)
1045-
spec = py_pytree.TreeSpec(
1046-
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1047-
)
1180+
spec = py_pytree.tree_structure(DummyType(1, 2))
10481181
serialized_spec = py_pytree.treespec_dumps(spec, 1)
10491182
self.assertIn("moo", serialized_spec)
10501183
roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
@@ -1082,9 +1215,7 @@ def __init__(self, x, y):
10821215
from_dumpable_context=lambda dumpable_context: None,
10831216
)
10841217

1085-
spec = py_pytree.TreeSpec(
1086-
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1087-
)
1218+
spec = py_pytree.tree_structure(DummyType(1, 2))
10881219

10891220
with self.assertRaisesRegex(
10901221
TypeError, "Object of type type is not JSON serializable"
@@ -1095,9 +1226,7 @@ def test_pytree_serialize_bad_protocol(self):
10951226
import json
10961227

10971228
Point = namedtuple("Point", ["x", "y"])
1098-
spec = py_pytree.TreeSpec(
1099-
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1100-
)
1229+
spec = py_pytree.tree_structure(Point(1, 2))
11011230
py_pytree._register_namedtuple(
11021231
Point,
11031232
serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",

tools/generate_torch_version.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def get_torch_version(sha: str | None = None) -> str:
9797

9898
with open(version_path, "w") as f:
9999
f.write("from typing import Optional\n\n")
100-
f.write("__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip']\n")
100+
f.write(
101+
"__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'xpu']\n"
102+
)
101103
f.write(f"__version__ = '{version}'\n")
102104
# NB: This is not 100% accurate, because you could have built the
103105
# library code with DEBUG, but csrc without DEBUG (in which case

torch/_dynamo/convert_frame.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -774,9 +774,6 @@ def compile_inner(
774774
dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
775775
)
776776
)
777-
stack.enter_context(
778-
_WaitCounter("pytorch.wait_counter.dynamo_compile").guard()
779-
)
780777
stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
781778
stack.enter_context(CompileTimeInstructionCounter.record())
782779
return _compile_inner(code, one_graph, hooks, transform)
@@ -957,7 +954,9 @@ def count_args(code: CodeType) -> int:
957954
chromium_event_timed(
958955
"dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True
959956
),
957+
_WaitCounter("pytorch.wait_counter.entire_forward_compile").guard(),
960958
metrics_context,
959+
_WaitCounter("pytorch.wait_counter.dynamo_compile").guard(),
961960
):
962961
restart_reasons: set[str] = set()
963962
# This is shared across restarts

torch/_dynamo/polyfills/pytree.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,10 @@ def _(*args: Any, **kwargs: Any) -> bool:
5656
"structseq_fields",
5757
):
5858
__func = getattr(optree, __name)
59-
substitute_in_graph(__func, can_constant_fold_through=True)(
59+
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
6060
__func.__python_implementation__
6161
)
62+
__all__ += [__name] # noqa: PLE0604
6263
del __func
6364
del __name
6465

torch/_export/serde/serialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1243,7 +1243,7 @@ def serialize_treespec(self, treespec):
12431243
def store_namedtuple_fields(ts):
12441244
if ts.type is None:
12451245
return
1246-
if ts.type == namedtuple:
1246+
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
12471247
serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ts.context].serialized_type_name
12481248
if serialized_type_name in self.treespec_namedtuple_fields:
12491249
field_names = self.treespec_namedtuple_fields[serialized_type_name].field_names

torch/_functorch/_aot_autograd/runtime_wrappers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from torch._prims_common import CUDARngStateHelper
3232
from torch._subclasses import FakeTensor
3333
from torch.fx.experimental._backward_state import BackwardState
34+
from torch.monitor import _WaitCounter
3435
from torch.multiprocessing.reductions import StorageWeakRef
3536
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
3637

@@ -2225,7 +2226,9 @@ def _backward_impl(ctx, all_args):
22252226
dynamo_compile_column_us="backward_cumulative_compile_time_us",
22262227
log_waitcounter=True,
22272228
waitcounter_name_override="entire_backward_compile",
2228-
):
2229+
), _WaitCounter(
2230+
"pytorch.wait_counter.dynamo_compile"
2231+
).guard():
22292232
CompileEventLogger.compilation_metric(is_forward=False)
22302233
# See Note: [Backward graph lazy lowering]
22312234
CompiledFunction.compiled_bw = aot_config.bw_compiler(

torch/_inductor/compile_fx.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -620,15 +620,6 @@ def compile_fx_inner(
620620
dynamo_compile_column_us="inductor_cumulative_compile_time_us",
621621
)
622622
)
623-
# NB: Why is this the dynamo_compile counter? The rule here is that
624-
# if it gets an entry in the dynamo_compile table, we also want to
625-
# tick up the wait counter. We have to displeasingly manually trigger
626-
# the counter here because we may dropped into compile_fx directly
627-
# from lazy backwards compilation.
628-
stack.enter_context(_WaitCounter("pytorch.wait_counter.dynamo_compile").guard())
629-
stack.enter_context(
630-
_WaitCounter("pytorch.wait_counter.all_compilation_types").guard()
631-
)
632623

633624
if torch._dynamo.callback_handler.prevent_duplicate_callbacks:
634625
stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
@@ -691,7 +682,6 @@ def _compile_fx_inner(
691682

692683
with (
693684
_WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _,
694-
_WaitCounter("pytorch.wait_counter.all_compilation_types").guard(),
695685
):
696686
use_cache = (
697687
not config.force_disable_caches

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import torch
3333
from torch._prims_common import compute_required_storage_length
34+
from torch.monitor import _WaitCounter
3435
from torch.utils._ordered_set import OrderedSet
3536

3637
from ..triton_bundler import TritonBundler
@@ -815,13 +816,18 @@ def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
815816
return self.maybe_clone_args(OrderedSet(), *args, **kwargs)
816817

817818
def benchmark_all_configs(self, *args, **kwargs):
818-
with dynamo_timed(
819-
"CachingAutotuner.benchmark_all_configs",
820-
log_pt2_compile_event=True,
821-
metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
822-
dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
823-
compile_id=self.compile_id,
824-
is_backward=self.is_backward,
819+
with (
820+
dynamo_timed(
821+
"CachingAutotuner.benchmark_all_configs",
822+
log_pt2_compile_event=True,
823+
metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
824+
dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
825+
compile_id=self.compile_id,
826+
is_backward=self.is_backward,
827+
log_waitcounter=True,
828+
waitcounter_name_override="triton_autotuner",
829+
),
830+
_WaitCounter("pytorch.wait_counter.dynamo_compile").guard(),
825831
):
826832
timings = {
827833
launcher: self.bench(launcher, *args, **kwargs)

torch/autograd/forward_ad.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# mypy: allow-untyped-defs
22
import os
3-
from collections import namedtuple
4-
from typing import Any
3+
from typing import Any, NamedTuple, Optional
54

65
import torch
76

@@ -129,16 +128,15 @@ def make_dual(tensor, tangent, *, level=None):
129128
return torch._VF._make_dual(tensor, tangent, level=level)
130129

131130

132-
_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"])
133-
134-
135-
class UnpackedDualTensor(_UnpackedDualTensor):
131+
class UnpackedDualTensor(NamedTuple):
136132
r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.
137133
138134
See :func:`unpack_dual` for more details.
139-
140135
"""
141136

137+
primal: torch.Tensor
138+
tangent: Optional[torch.Tensor]
139+
142140

143141
def unpack_dual(tensor, *, level=None):
144142
r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient.

0 commit comments

Comments
 (0)