Skip to content

Commit 7354422

Browse files
committed
Fix vector operators
According to https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types operators should return NotImplemented This also allows to use vpythons vectors for custom classes
1 parent 5dbae25 commit 7354422

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

vpython/cyvector.pyx

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,23 +62,26 @@ cdef class vector(object):
6262
return '<{:.6g}, {:.6g}, {:.6g}>'.format(self._x, self._y, self._z)
6363

6464
def __add__(self,other):
65-
return vector(self._x + other._x, self._y + other._y, self._z + other._z)
65+
if type(other) is vector:
66+
return vector(self._x + other._x, self._y + other._y, self._z + other._z)
67+
return NotImplemented
6668

6769
def __truediv__(self, other): # Python 3, or Python 2 + future division
6870
if isinstance(other, (int, float)):
6971
return vector(self._x / other, self._y / other, self._z / other)
70-
raise TypeError('a vector can only be divided by a scalar')
72+
return NotImplemented
7173

7274
def __sub__(self,other):
73-
return vector(self._x - other._x, self._y - other._y, self._z - other._z)
75+
if type(other) is vector:
76+
return vector(self._x - other._x, self._y - other._y, self._z - other._z)
77+
return NotImplemented
7478

7579
def __mul__(self, other): ## in cython order of arguments is arbitrary, rmul doesn't exist
7680
if isinstance(other, (int, float)):
7781
return vector(self._x * other, self._y * other, self._z * other)
7882
elif isinstance(self, (int, float)):
7983
return vector(self * other._x, self * other._y, self * other._z)
80-
else:
81-
raise TypeError('a vector can only be multiplied by a scalar', self, other)
84+
return NotImplemented
8285

8386
def __eq__(self,other):
8487
if type(self) is vector and type(other) is vector:

vpython/vector.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,29 +53,30 @@ def __str__(self):
5353
def __repr__(self):
5454
return '<{:.6g}, {:.6g}, {:.6g}>'.format(self._x, self._y, self._z)
5555

56-
def __add__(self,other):
57-
return vector(self._x + other._x, self._y + other._y, self._z + other._z)
56+
def __add__(self, other):
57+
if type(other) is vector:
58+
return vector(self._x + other._x, self._y + other._y, self._z + other._z)
59+
return NotImplemented
5860

59-
def __sub__(self,other):
60-
return vector(self._x - other._x, self._y - other._y, self._z - other._z)
61+
def __sub__(self, other):
62+
if type(other) is vector:
63+
return vector(self._x - other._x, self._y - other._y, self._z - other._z)
64+
return NotImplemented
6165

6266
def __truediv__(self, other): # used by Python 3, and by Python 2 in the presence of __future__ division
63-
try:
67+
if isinstance(other, (float, int)):
6468
return vector(self._x / other, self._y / other, self._z / other)
65-
except:
66-
raise TypeError('a vector can only be divided by a scalar')
69+
return NotImplemented
6770

6871
def __mul__(self, other):
69-
try:
72+
if isinstance(other, (float, int)):
7073
return vector(self._x * other, self._y * other, self._z * other)
71-
except:
72-
raise TypeError('a vector can only be multiplied by a scalar')
74+
return NotImplemented
7375

7476
def __rmul__(self, other):
75-
try:
77+
if isinstance(other, (float, int)):
7678
return vector(self._x * other, self._y * other, self._z * other)
77-
except:
78-
raise TypeError('a vector can only be multiplied by a scalar')
79+
return NotImplemented
7980

8081
def __eq__(self,other):
8182
if type(self) is vector and type(other) is vector:

0 commit comments

Comments
 (0)