Skip to content

Commit e353dda

Browse files
Merge pull request #136 from BruceSherwood/Implement_==_and_!=_for_vectors
Implement == and != for vectors
2 parents fafaca0 + e40cff2 commit e353dda

File tree

3 files changed

+105
-85
lines changed

3 files changed

+105
-85
lines changed

vpython/cyvector.pyx

Lines changed: 61 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ cdef class vector(object):
3131
self._z = args[2]
3232
elif len(args) == 1 and isinstance(args[0], vector): # make a copy of a vector
3333
other = args[0]
34-
self._x = other.x
35-
self._y = other.y
36-
self._z = other.z
34+
self._x = other._x
35+
self._y = other._y
36+
self._z = other._z
3737
else:
3838
raise TypeError('A vector needs 3 components.')
3939
self.on_change = self.ignore
@@ -43,14 +43,14 @@ cdef class vector(object):
4343

4444
property value:
4545
def __get__(self):
46-
return [self.x, self.y, self.z]
46+
return [self._x, self._y, self._z]
4747
def __set__(self, other):
48-
self._x = other.x
49-
self._y = other.y
50-
self._z = other.z
48+
self._x = other._x
49+
self._y = other._y
50+
self._z = other._z
5151

5252
def __neg__(self): ## seems like this must come before properties (???)
53-
return vector(-self.x, -self.y, -self.z)
53+
return vector(-self._x, -self._y, -self._z)
5454

5555
def __pos__(self):
5656
return self
@@ -62,24 +62,34 @@ cdef class vector(object):
6262
return '<{:.6g}, {:.6g}, {:.6g}>'.format(self._x, self._y, self._z)
6363

6464
def __add__(self,other):
65-
return vector(self.x + other.x, self.y + other.y, self.z + other.z)
65+
return vector(self._x + other._x, self._y + other._y, self._z + other._z)
6666

6767
def __truediv__(self, other): # Python 3, or Python 2 + future division
6868
if isinstance(other, (int, float)):
69-
return vector(self.x / other, self.y / other, self.z / other)
69+
return vector(self._x / other, self._y / other, self._z / other)
7070
raise TypeError('a vector can only be divided by a scalar')
7171

7272
def __sub__(self,other):
73-
return vector(self.x - other.x, self.y - other.y, self.z - other.z)
73+
return vector(self._x - other._x, self._y - other._y, self._z - other._z)
7474

7575
def __mul__(self, other): ## in cython order of arguments is arbitrary, rmul doesn't exist
7676
if isinstance(other, (int, float)):
77-
return vector(self.x * other, self.y * other, self. z * other)
77+
return vector(self._x * other, self._y * other, self._z * other)
7878
elif isinstance(self, (int, float)):
79-
return vector(self * other.x, self * other.y, self * other.z)
79+
return vector(self * other._x, self * other._y, self * other._z)
8080
else:
8181
raise TypeError('a vector can only be multiplied by a scalar', self, other)
8282

83+
def __eq__(self,other):
84+
if type(self) is vector and type(other) is vector:
85+
return self.equals(other)
86+
return False
87+
88+
def __ne__(self,other):
89+
if type(self) is vector and type(other) is vector:
90+
return not self.equals(other)
91+
return True
92+
8393
property x:
8494
def __get__(self):
8595
return self._x
@@ -103,18 +113,18 @@ cdef class vector(object):
103113

104114
property mag:
105115
def __get__(self):
106-
return sqrt(self.x**2 + self.y**2 + self.z**2)
116+
return sqrt(self._x**2 + self._y**2 + self._z**2)
107117
def __set__(self, value):
108118
cdef vector normA
109119
normA = self.hat
110-
self.x = value * normA.x
111-
self.y = value * normA.y
112-
self.z = value * normA.z
120+
self.x = value * normA._x
121+
self.y = value * normA._y
122+
self.z = value * normA.vz
113123
self.on_change()
114124

115125
property mag2:
116126
def __get__(self):
117-
return (self.x**2 + self.y**2 + self.z**2)
127+
return (self._x**2 + self._y**2 + self._z**2)
118128
def __set__(self, value):
119129
cdef double v
120130
v = sqrt(value)
@@ -134,30 +144,30 @@ cdef class vector(object):
134144
smag = self.mag
135145
cdef vector normA
136146
normA = value.hat
137-
self.x = smag * normA.x
138-
self.y = smag * normA.y
139-
self.z = smag * normA.z
147+
self.x = smag * normA._x
148+
self.y = smag * normA._y
149+
self.z = smag * normA._z
140150
self.on_change()
141151

142152

143153
cpdef vector norm(self):
144154
return self.hat
145155

146156
cpdef double dot(self,other):
147-
return ( self.x*other.x + self.y*other.y + self.z*other.z )
157+
return ( self._x*other._x + self._y*other._y + self._z*other._z )
148158

149159
cpdef vector cross(self,other):
150-
return vector( self.y*other.z-self.z*other.y,
151-
self.z*other.x-self.x*other.z,
152-
self.x*other.y-self.y*other.x )
160+
return vector( self._y*other._z-self._z*other._y,
161+
self._z*other._x-self._x*other._z,
162+
self._x*other._y-self._y*other._x )
153163

154164
cpdef vector proj(self,other):
155165
cdef vector normB
156166
normB = other.hat
157167
return self.dot(normB) * normB
158168

159169
cpdef bint equals(self,other):
160-
return ( self.x == other.x and self.y == other.y and self.z == other.z )
170+
return ( self._x == other._x and self._y == other._y and self._z == other._z )
161171

162172
cpdef double comp(self,other): ## result is a scalar
163173
cdef vector normB
@@ -175,16 +185,16 @@ cdef class vector(object):
175185

176186
cpdef vector rotate(self, double angle=0., vector axis=None):
177187
cdef vector u
178-
if axis == None:
188+
if axis is None:
179189
u = vector(0,0,1)
180190
else:
181191
u = axis.hat
182192
cdef double c = cos(angle)
183193
cdef double s = sin(angle)
184194
cdef double t = 1.0 - c
185-
cdef double x = u.x
186-
cdef double y = u.y
187-
cdef double z = u.z
195+
cdef double x = u._x
196+
cdef double y = u._y
197+
cdef double z = u._z
188198
cdef double m11 = t*x*x+c
189199
cdef double m12 = t*x*y-z*s
190200
cdef double m13 = t*x*z+y*s
@@ -194,25 +204,25 @@ cdef class vector(object):
194204
cdef double m31 = t*x*z-y*s
195205
cdef double m32 = t*y*z+x*s
196206
cdef double m33 = t*z*z+c
197-
cdef double sx = self.x
198-
cdef double sy = self.y
199-
cdef double sz = self.z
207+
cdef double sx = self._x
208+
cdef double sy = self._y
209+
cdef double sz = self._z
200210
return vector( (m11*sx + m12*sy + m13*sz),
201211
(m21*sx + m22*sy + m23*sz),
202212
(m31*sx + m32*sy + m33*sz) )
203213

204214
cpdef rotate_in_place(self, double angle=0., vector axis=None):
205215
cdef vector u
206-
if axis == None:
216+
if axis is None:
207217
u = vector(0,0,1)
208218
else:
209219
u = axis.hat
210220
cdef double c = cos(angle)
211221
cdef double s = sin(angle)
212222
cdef double t = 1.0 - c
213-
cdef double x = u.x
214-
cdef double y = u.y
215-
cdef double z = u.z
223+
cdef double x = u._x
224+
cdef double y = u._y
225+
cdef double z = u._z
216226
cdef double m11 = t*x*x+c
217227
cdef double m12 = t*x*y-z*s
218228
cdef double m13 = t*x*z+y*s
@@ -222,9 +232,9 @@ cdef class vector(object):
222232
cdef double m31 = t*x*z-y*s
223233
cdef double m32 = t*y*z+x*s
224234
cdef double m33 = t*z*z+c
225-
cdef double sx = self.x
226-
cdef double sy = self.y
227-
cdef double sz = self.z
235+
cdef double sx = self._x
236+
cdef double sy = self._y
237+
cdef double sz = self._z
228238
self._x = m11*sx + m12*sy + m13*sz
229239
self._y = m21*sx + m22*sy + m23*sz
230240
self._z = m31*sx + m32*sy + m33*sz
@@ -234,9 +244,9 @@ cpdef object_rotate(vector objaxis, vector objup, double angle, vector axis):
234244
cdef double c = cos(angle)
235245
cdef double s = sin(angle)
236246
cdef double t = 1.0 - c
237-
cdef double x = u.x
238-
cdef double y = u.y
239-
cdef double z = u.z
247+
cdef double x = u._x
248+
cdef double y = u._y
249+
cdef double z = u._z
240250
cdef double m11 = t*x*x+c
241251
cdef double m12 = t*x*y-z*s
242252
cdef double m13 = t*x*z+y*s
@@ -246,15 +256,15 @@ cpdef object_rotate(vector objaxis, vector objup, double angle, vector axis):
246256
cdef double m31 = t*x*z-y*s
247257
cdef double m32 = t*y*z+x*s
248258
cdef double m33 = t*z*z+c
249-
cdef double sx = objaxis.x
250-
cdef double sy = objaxis.y
251-
cdef double sz = objaxis.z
259+
cdef double sx = objaxis._x
260+
cdef double sy = objaxis._y
261+
cdef double sz = objaxis._z
252262
objaxis._x = m11*sx + m12*sy + m13*sz # avoid creating a new vector object
253263
objaxis._y = m21*sx + m22*sy + m23*sz
254264
objaxis._z = m31*sx + m32*sy + m33*sz
255-
sx = objup.x
256-
sy = objup.y
257-
sz = objup.z
265+
sx = objup._x
266+
sy = objup._y
267+
sz = objup._z
258268
objup._x = m11*sx + m12*sy + m13*sz
259269
objup._y = m21*sx + m22*sy + m23*sz
260270
objup._z = m31*sx + m32*sy + m33*sz
@@ -294,7 +304,7 @@ cpdef vector rotate(vector A, double angle = 0., vector axis = None):
294304
cpdef vector adjust_up(vector oldaxis, vector newaxis, vector up, vector save_oldaxis): # adjust up when axis is changed
295305
cdef double angle
296306
cdef vector rotaxis
297-
if abs(newaxis.x) + abs(newaxis.y) + abs(newaxis.z) == 0:
307+
if abs(newaxis._x) + abs(newaxis._y) + abs(newaxis._z) == 0:
298308
# If axis has changed to <0,0,0>, must save the old axis to restore later
299309
if save_oldaxis is None: save_oldaxis = oldaxis
300310
return save_oldaxis
@@ -321,7 +331,7 @@ cpdef vector adjust_up(vector oldaxis, vector newaxis, vector up, vector save_ol
321331
cpdef vector adjust_axis(vector oldup, vector newup, vector axis, vector save_oldup): # adjust axis when up is changed
322332
cdef double angle
323333
cdef vector rotaxis
324-
if abs(newup.x) + abs(newup.y) + abs(newup.z) == 0:
334+
if abs(newup._x) + abs(newup._y) + abs(newup._z) == 0:
325335
# If up will be set to <0,0,0>, must save the old up to restore later
326336
if save_oldup is None: save_oldup = oldup
327337
return save_oldup

vpython/vector.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ def __rmul__(self, other):
7777
except:
7878
raise TypeError('a vector can only be multiplied by a scalar')
7979

80+
def __eq__(self,other):
81+
if type(self) is vector and type(other) is vector:
82+
return self.equals(other)
83+
return False
84+
85+
def __ne__(self,other):
86+
if type(self) is vector and type(other) is vector:
87+
return not self.equals(other)
88+
return True
89+
8090
@property
8191
def x(self):
8292
return self._x
@@ -171,16 +181,16 @@ def diff_angle(self, other):
171181
return acos(a)
172182

173183
def rotate(self, angle=0., axis=None):
174-
if axis == None:
184+
if axis is None:
175185
u = vector(0,0,1)
176186
else:
177187
u = axis.hat
178188
c = cos(angle)
179189
s = sin(angle)
180190
t = 1.0 - c
181-
x = u.x
182-
y = u.y
183-
z = u.z
191+
x = u._x
192+
y = u._y
193+
z = u._z
184194
m11 = t*x*x+c
185195
m12 = t*x*y-z*s
186196
m13 = t*x*z+y*s
@@ -190,24 +200,24 @@ def rotate(self, angle=0., axis=None):
190200
m31 = t*x*z-y*s
191201
m32 = t*y*z+x*s
192202
m33 = t*z*z+c
193-
sx = self.x
194-
sy = self.y
195-
sz = self.z
203+
sx = self._x
204+
sy = self._y
205+
sz = self._z
196206
return vector( (m11*sx + m12*sy + m13*sz),
197207
(m21*sx + m22*sy + m23*sz),
198208
(m31*sx + m32*sy + m33*sz) )
199209

200210
def rotate_in_place(self, angle=0., axis=None):
201-
if axis == None:
211+
if axis is None:
202212
u = vector(0,0,1)
203213
else:
204214
u = axis.hat
205215
c = cos(angle)
206216
s = sin(angle)
207217
t = 1.0 - c
208-
x = u.x
209-
y = u.y
210-
z = u.z
218+
x = u._x
219+
y = u._y
220+
z = u._z
211221
m11 = t*x*x+c
212222
m12 = t*x*y-z*s
213223
m13 = t*x*z+y*s
@@ -217,9 +227,9 @@ def rotate_in_place(self, angle=0., axis=None):
217227
m31 = t*x*z-y*s
218228
m32 = t*y*z+x*s
219229
m33 = t*z*z+c
220-
sx = self.x
221-
sy = self.y
222-
sz = self.z
230+
sx = self._x
231+
sy = self._y
232+
sz = self._z
223233
self._x = m11*sx + m12*sy + m13*sz
224234
self._y = m21*sx + m22*sy + m23*sz
225235
self._z = m31*sx + m32*sy + m33*sz
@@ -229,9 +239,9 @@ def object_rotate(objaxis, objup, angle, axis):
229239
c = cos(angle)
230240
s = sin(angle)
231241
t = 1.0 - c
232-
x = u.x
233-
y = u.y
234-
z = u.z
242+
x = u._x
243+
y = u._y
244+
z = u._z
235245
m11 = t*x*x+c
236246
m12 = t*x*y-z*s
237247
m13 = t*x*z+y*s
@@ -241,15 +251,15 @@ def object_rotate(objaxis, objup, angle, axis):
241251
m31 = t*x*z-y*s
242252
m32 = t*y*z+x*s
243253
m33 = t*z*z+c
244-
sx = objaxis.x
245-
sy = objaxis.y
246-
sz = objaxis.z
254+
sx = objaxis._x
255+
sy = objaxis._y
256+
sz = objaxis._z
247257
objaxis._x = m11*sx + m12*sy + m13*sz # avoid creating a new vector object
248258
objaxis._y = m21*sx + m22*sy + m23*sz
249259
objaxis._z = m31*sx + m32*sy + m33*sz
250-
sx = objup.x
251-
sy = objup.y
252-
sz = objup.z
260+
sx = objup._x
261+
sy = objup._y
262+
sz = objup._z
253263
objup._x = m11*sx + m12*sy + m13*sz
254264
objup._y = m21*sx + m22*sy + m23*sz
255265
objup._z = m31*sx + m32*sy + m33*sz
@@ -285,7 +295,7 @@ def rotate(A, angle=0., axis = None):
285295
return A.rotate(angle,axis)
286296

287297
def adjust_up(oldaxis, newaxis, up, save_oldaxis): # adjust up when axis is changed
288-
if abs(newaxis.x) + abs(newaxis.y) + abs(newaxis.z) == 0:
298+
if abs(newaxis._x) + abs(newaxis._y) + abs(newaxis._z) == 0:
289299
# If axis has changed to <0,0,0>, must save the old axis to restore later
290300
if save_oldaxis is None: save_oldaxis = oldaxis
291301
return save_oldaxis
@@ -310,7 +320,7 @@ def adjust_up(oldaxis, newaxis, up, save_oldaxis): # adjust up when axis is chan
310320
return save_oldaxis
311321

312322
def adjust_axis(oldup, newup, axis, save_oldup): # adjust axis when up is changed
313-
if abs(newup.x) + abs(newup.y) + abs(newup.z) == 0:
323+
if abs(newup._x) + abs(newup._y) + abs(newup._z) == 0:
314324
# If up will be set to <0,0,0>, must save the old up to restore later
315325
if save_oldup is None: save_oldup = oldup
316326
return save_oldup

0 commit comments

Comments
 (0)