11#include < torch/csrc/fx/node.h>
22
3- #include < c10/util/SmallVector.h>
43#include < structmember.h>
54#include < torch/csrc/utils/object_ptr.h>
65#include < torch/csrc/utils/pythoncapi_compat.h>
7- #include < algorithm>
86
97namespace {
108
11- using NodeSortKey = c10::SmallVector<int64_t , 4 >;
129struct NodeBase ;
1310
1411// Thrown to exit out of a C++ function and return an error to Python.
@@ -166,22 +163,7 @@ struct NodeBase {
166163 PyObject* users;
167164 PyObject* _repr_fn;
168165 PyObject* meta;
169- // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
170- alignas (NodeSortKey) char sort_key_buf[sizeof (NodeSortKey)];
171-
172- inline NodeSortKey& sort_key () {
173- return *reinterpret_cast <NodeSortKey*>(sort_key_buf);
174- }
175-
176- // Equivalent to:
177- // p, n = self._prev, self._next
178- // p._next, n._prev = n, p
179- inline void remove_from_list () {
180- NodeBase* p = this ->_prev ;
181- NodeBase* n = this ->_next ;
182- p->_next = n;
183- n->_prev = p;
184- }
166+ PyObject* _sort_key;
185167};
186168
187169static PyObject* NodeBase_new (
@@ -191,8 +173,6 @@ static PyObject* NodeBase_new(
191173 PyObject* self = type->tp_alloc (type, 0 );
192174 if (!self)
193175 return nullptr ;
194- new (reinterpret_cast <NodeBase*>(self)->sort_key_buf )
195- NodeSortKey (); // placement new does not allocate
196176 return self;
197177}
198178
@@ -221,6 +201,7 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
221201 self->users = PyDict_New ();
222202 self->_repr_fn = Py_NewRef (Py_None);
223203 self->meta = PyDict_New ();
204+ self->_sort_key = PyTuple_New (0 );
224205 return 0 ;
225206}
226207
@@ -240,6 +221,7 @@ static struct PyMemberDef NodeBase_members[] = {
240221 {" users" , T_OBJECT_EX, offsetof (NodeBase, users), 0 , nullptr },
241222 {" _repr_fn" , T_OBJECT_EX, offsetof (NodeBase, _repr_fn), 0 , nullptr },
242223 {" meta" , T_OBJECT_EX, offsetof (NodeBase, meta), 0 , nullptr },
224+ {" _sort_key" , T_OBJECT_EX, offsetof (NodeBase, _sort_key), 0 , nullptr },
243225 {nullptr } /* Sentinel */
244226};
245227
@@ -257,6 +239,7 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
257239 Py_VISIT (self->users );
258240 Py_VISIT (self->_repr_fn );
259241 Py_VISIT (self->meta );
242+ Py_VISIT (self->_sort_key );
260243 return 0 ;
261244}
262245
@@ -274,12 +257,12 @@ static int NodeBase_clear(NodeBase* self) {
274257 Py_CLEAR (self->users );
275258 Py_CLEAR (self->_repr_fn );
276259 Py_CLEAR (self->meta );
260+ Py_CLEAR (self->_sort_key );
277261 return 0 ;
278262}
279263
280264static void NodeBase_dealloc (PyObject* self) {
281265 PyObject_GC_UnTrack (self);
282- reinterpret_cast <NodeBase*>(self)->sort_key ().~NodeSortKey ();
283266 (void )NodeBase_clear ((NodeBase*)self);
284267 Py_TYPE (self)->tp_free (self);
285268}
@@ -338,191 +321,15 @@ static PyObject* NodeBase__update_args_kwargs(
338321 }
339322}
340323
341- static PyObject* NodeBase__remove_from_list (
342- PyObject* self,
343- PyObject* _ignored) {
344- reinterpret_cast <NodeBase*>(self)->remove_from_list ();
345- Py_RETURN_NONE;
346- }
347-
348- static PyObject* NodeBase__prepend (PyObject* self_, PyObject* arg) {
349- if (self_ == arg) {
350- Py_RETURN_NONE;
351- }
352- if (!is_node (arg)) {
353- PyErr_SetString (PyExc_TypeError, " _prepend() argument must be a Node" );
354- return nullptr ;
355- }
356- NodeBase* self = reinterpret_cast <NodeBase*>(self_);
357- NodeBase* x = reinterpret_cast <NodeBase*>(arg);
358- if (self->graph != x->graph ) {
359- PyErr_SetString (
360- PyExc_AssertionError,
361- " Attempting to move a Node into a different Graph" );
362- return nullptr ;
363- }
364-
365- x->remove_from_list ();
366- NodeBase* p = self->_prev ;
367- p->_next = x;
368- x->_prev = p;
369- x->_next = self;
370- self->_prev = x;
371-
372- // Now compute x.sort_key()
373- const NodeSortKey& psk = x->_prev ->sort_key ();
374- const NodeSortKey& nsk = x->_next ->sort_key ();
375- if (psk.size () > nsk.size ()) {
376- // prefix = psk[: len(nsk)+1]
377- size_t slice_len = nsk.size () + 1 ;
378- NodeSortKey prefix (psk.begin (), psk.begin () + slice_len);
379- // last element is idx => increment by 1
380- prefix.back ()++;
381- x->sort_key () = std::move (prefix);
382- } else if (psk.size () < nsk.size ()) {
383- // prefix = nsk[: len(psk)+1]
384- size_t slice_len = psk.size () + 1 ;
385- NodeSortKey prefix (nsk.begin (), nsk.begin () + slice_len);
386- // last element is idx => decrement by 1
387- prefix.back ()--;
388- x->sort_key () = std::move (prefix);
389- } else {
390- // same length => add a 0
391- x->sort_key () = psk;
392- x->sort_key ().emplace_back (0 );
393- }
394- Py_RETURN_NONE;
395- }
396-
397- // __lt__(self, other): Return self.sort_key < other.sort_key
398- static PyObject* NodeBase___lt__ (PyObject* self, PyObject* other) {
399- // METH_O => one argument: 'other'
400- if (!is_node (other)) {
401- Py_RETURN_NOTIMPLEMENTED;
402- }
403- const NodeSortKey& lhs = reinterpret_cast <NodeBase*>(self)->sort_key ();
404- const NodeSortKey& rhs = reinterpret_cast <NodeBase*>(other)->sort_key ();
405- bool less = std::lexicographical_compare (
406- lhs.begin (), lhs.end (), rhs.begin (), rhs.end ());
407- if (less)
408- Py_RETURN_TRUE;
409- Py_RETURN_FALSE;
410- }
411-
412- // __gt__(self, other): Return self.sort_key() > other.sort_key
413- static PyObject* NodeBase___gt__ (PyObject* self, PyObject* other) {
414- if (!is_node (other)) {
415- Py_RETURN_NOTIMPLEMENTED;
416- }
417- const NodeSortKey& lhs = reinterpret_cast <NodeBase*>(self)->sort_key ();
418- const NodeSortKey& rhs = reinterpret_cast <NodeBase*>(other)->sort_key ();
419- // "a > b" is equivalent to "b < a"
420- bool greater = std::lexicographical_compare (
421- rhs.begin (), rhs.end (), lhs.begin (), lhs.end ());
422- if (greater)
423- Py_RETURN_TRUE;
424- Py_RETURN_FALSE;
425- }
426-
427- static PyObject* NodeBase___ge__ (PyObject* self, PyObject* other) {
428- if (self == other) {
429- Py_RETURN_TRUE;
430- }
431- return NodeBase___gt__ (self, other);
432- }
433-
434- // __le__(self, other): Return not (self > other)
435- static PyObject* NodeBase___le__ (PyObject* self, PyObject* other) {
436- if (self == other) {
437- Py_RETURN_TRUE;
438- }
439- return NodeBase___lt__ (self, other);
440- }
441-
442- // Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
443- // Only used by pickle/__getstate__
444- static PyObject* NodeBase_get_sort_key (PyObject* self, void * /* closure*/ ) {
445- NodeBase* node = reinterpret_cast <NodeBase*>(self);
446- const NodeSortKey& vec = node->sort_key ();
447- Py_ssize_t n = static_cast <Py_ssize_t>(vec.size ());
448- THPObjectPtr tuple (PyTuple_New (n));
449- if (!tuple) {
450- return nullptr ; // Out of memory
451- }
452- for (Py_ssize_t i = 0 ; i < n; i++) {
453- PyTuple_SET_ITEM (tuple.get (), i, PyLong_FromSsize_t (vec[i]));
454- }
455- return tuple.release ();
456- }
457-
458- // Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
459- // node._sort_key = (1,2,3) Only used by pickle/__setstate__
460- static int NodeBase_set_sort_key (
461- PyObject* self,
462- PyObject* value,
463- void * /* closure*/ ) {
464- NodeBase* node = reinterpret_cast <NodeBase*>(self);
465- if (!PyTuple_Check (value)) {
466- PyErr_SetString (PyExc_TypeError, " _sort_key must be an tuple of ints" );
467- return -1 ;
468- }
469- Py_ssize_t size = PyTuple_GET_SIZE (value);
470- NodeSortKey new_vec;
471- new_vec.reserve (size);
472- for (Py_ssize_t i = 0 ; i < size; i++) {
473- int64_t val = PyLong_AsSsize_t (PyTuple_GET_ITEM (value, i));
474- if (val == -1 && PyErr_Occurred ()) {
475- return -1 ;
476- }
477- new_vec.emplace_back (val);
478- }
479- node->sort_key () = std::move (new_vec);
480- return 0 ;
481- }
482-
483324// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
484325static PyMethodDef NodeBase_methods[] = {
485326 {" _update_args_kwargs" ,
486327 (PyCFunction)(void *)(NodeBase__update_args_kwargs),
487328 METH_FASTCALL,
488329 " Internal method: do not call directly." },
489- {" _remove_from_list" ,
490- (PyCFunction)(void *)(NodeBase__remove_from_list),
491- METH_NOARGS,
492- " Internal method: do not call directly." },
493- {" _prepend" ,
494- (PyCFunction)(void *)(NodeBase__prepend),
495- METH_O,
496- " Internal method: do not call directly." },
497- {" __lt__" ,
498- (PyCFunction)(void *)NodeBase___lt__,
499- METH_O,
500- " Return True if self.sort_key < other.sort_key" },
501- {" __gt__" ,
502- (PyCFunction)(void *)NodeBase___gt__,
503- METH_O,
504- " Return True if self.sort_key > other.sort_key" },
505- {" __ge__" ,
506- (PyCFunction)(void *)NodeBase___ge__,
507- METH_O,
508- " Return True if self.sort_key >= other.sort_key" },
509- {" __le__" ,
510- (PyCFunction)(void *)NodeBase___le__,
511- METH_O,
512- " Return True if self.sort_key <= other.sort_key" },
513330 {nullptr , nullptr , 0 , nullptr } // Sentinel
514331};
515332
516- // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
517- static PyGetSetDef NodeBase_getset[] = {
518- {" _sort_key" , // attribute name in Python
519- (getter)NodeBase_get_sort_key, // C getter function
520- (setter)NodeBase_set_sort_key, // C setter function
521- (char *)" The sort key as a tuple of ints" , // docstring
522- nullptr },
523- {nullptr , nullptr , nullptr , nullptr , nullptr } // Sentinel
524- };
525-
526333PyTypeObject NodeBaseType = {
527334 PyVarObject_HEAD_INIT (nullptr , 0 )
528335 " torch._C._NodeBase" , /* tp_name */
@@ -554,7 +361,7 @@ PyTypeObject NodeBaseType = {
554361 nullptr , /* tp_iternext */
555362 NodeBase_methods, /* tp_methods */
556363 NodeBase_members, /* tp_members */
557- NodeBase_getset , /* tp_getset */
364+ nullptr , /* tp_getset */
558365 nullptr , /* tp_base */
559366 nullptr , /* tp_dict */
560367 nullptr , /* tp_descr_get */
0 commit comments