Skip to content

Commit 9c0648d

Browse files
Figure out the class hierarchy for transforms. (#32)
* Figure out the class hierarchy for transforms. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Object-based transforms are done. * So is a NumPy-based transformation. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 487a94c commit 9c0648d

File tree

18 files changed

+892
-134
lines changed

18 files changed

+892
-134
lines changed

src/vector/backends/numpy_.py

Lines changed: 128 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,20 @@
55

66
import numpy
77

8-
import vector.compute
8+
import vector.backends.object_
9+
import vector.geometry
10+
import vector.methods
11+
12+
13+
def getitem(array, where, object_type):
14+
if isinstance(where, str):
15+
return array.view(numpy.ndarray)[where]
16+
else:
17+
out = numpy.ndarray.__getitem__(array, where)
18+
if isinstance(out, numpy.void):
19+
return object_type(*out)
20+
else:
21+
return out
922

1023

1124
class AzimuthalNumpy:
@@ -21,98 +34,130 @@ class TemporalNumpy:
2134

2235

2336
class AzimuthalNumpyXY(numpy.ndarray, AzimuthalNumpy, vector.geometry.AzimuthalXY):
24-
def __iter__(self):
25-
yield self["x"].view(numpy.ndarray)
26-
yield self["y"].view(numpy.ndarray)
37+
@property
38+
def elements(self):
39+
return (self["x"], self["y"])
2740

2841
@property
2942
def x(self):
30-
return self["x"].view(numpy.ndarray)
43+
return self["x"]
3144

3245
@property
3346
def y(self):
34-
return self["y"].view(numpy.ndarray)
47+
return self["y"]
48+
49+
def __getitem__(self, where):
50+
return getitem(self, where, vector.backends.object_.AzimuthalObjectXY)
3551

3652

3753
class AzimuthalNumpyRhoPhi(
3854
numpy.ndarray, AzimuthalNumpy, vector.geometry.AzimuthalRhoPhi
3955
):
40-
def __iter__(self):
41-
yield self["rho"].view(numpy.ndarray)
42-
yield self["phi"].view(numpy.ndarray)
56+
@property
57+
def elements(self):
58+
return (self["rho"], self["phi"])
4359

4460
@property
4561
def rho(self):
46-
return self["rho"].view(numpy.ndarray)
62+
return self["rho"]
4763

4864
@property
4965
def phi(self):
50-
return self["phi"].view(numpy.ndarray)
66+
return self["phi"]
67+
68+
def __getitem__(self, where):
69+
return getitem(self, where, vector.backends.object_.AzimuthalObjectRhoPhi)
5170

5271

5372
class LongitudinalNumpyZ(
5473
numpy.ndarray, LongitudinalNumpy, vector.geometry.LongitudinalZ
5574
):
56-
def __iter__(self):
57-
yield self["z"].view(numpy.ndarray)
75+
@property
76+
def elements(self):
77+
return (self["z"],)
5878

5979
@property
6080
def z(self):
61-
return self["z"].view(numpy.ndarray)
81+
return self["z"]
82+
83+
def __getitem__(self, where):
84+
return getitem(self, where, vector.backends.object_.LongitudinalObjectZ)
6285

6386

6487
class LongitudinalNumpyTheta(
6588
numpy.ndarray, LongitudinalNumpy, vector.geometry.LongitudinalTheta
6689
):
67-
def __iter__(self):
68-
yield self["theta"].view(numpy.ndarray)
90+
@property
91+
def elements(self):
92+
return (self["theta"],)
6993

7094
@property
7195
def theta(self):
72-
return self["theta"].view(numpy.ndarray)
96+
return self["theta"]
97+
98+
def __getitem__(self, where):
99+
return getitem(self, where, vector.backends.object_.LongitudinalObjectTheta)
73100

74101

75102
class LongitudinalNumpyEta(
76103
numpy.ndarray, LongitudinalNumpy, vector.geometry.LongitudinalEta
77104
):
78-
def __iter__(self):
79-
yield self["eta"].view(numpy.ndarray)
105+
@property
106+
def elements(self):
107+
return (self["eta"],)
80108

81109
@property
82110
def eta(self):
83-
return self["eta"].view(numpy.ndarray)
111+
return self["eta"]
112+
113+
def __getitem__(self, where):
114+
return getitem(self, where, vector.backends.object_.LongitudinalObjectEta)
84115

85116

86117
class LongitudinalNumpyW(
87118
numpy.ndarray, LongitudinalNumpy, vector.geometry.LongitudinalW
88119
):
89-
def __iter__(self):
90-
yield self["w"].view(numpy.ndarray)
120+
@property
121+
def elements(self):
122+
return (self["w"],)
91123

92124
@property
93125
def w(self):
94-
return self["w"].view(numpy.ndarray)
126+
return self["w"]
127+
128+
def __getitem__(self, where):
129+
return getitem(self, where, vector.backends.object_.LongitudinalObjectW)
95130

96131

97132
class TemporalNumpyT(numpy.ndarray, TemporalNumpy, vector.geometry.TemporalT):
98-
def __iter__(self):
99-
yield self["t"].view(numpy.ndarray)
133+
@property
134+
def elements(self):
135+
return (self["t"],)
100136

101137
@property
102138
def t(self):
103-
return self["t"].view(numpy.ndarray)
139+
return self["t"]
140+
141+
def __getitem__(self, where):
142+
return getitem(self, where, vector.backends.object_.TemporalObjectT)
104143

105144

106145
class TemporalNumpyTau(numpy.ndarray, TemporalNumpy, vector.geometry.TemporalTau):
107-
def __iter__(self):
108-
yield self["tau"].view(numpy.ndarray)
146+
@property
147+
def elements(self):
148+
return (self["tau"],)
109149

110150
@property
111151
def tau(self):
112-
return self["tau"].view(numpy.ndarray)
152+
return self["tau"]
153+
154+
def __getitem__(self, where):
155+
return getitem(self, where, vector.backends.object_.TemporalObjectTau)
156+
113157

158+
class PlanarNumpy(numpy.ndarray, vector.methods.Planar):
159+
lib = numpy
114160

115-
class PlanarVectorNumpy(numpy.ndarray, vector.geometry.Planar, vector.geometry.Vector):
116161
def __new__(cls, *args, **kwargs):
117162
return numpy.array(*args, **kwargs).view(cls)
118163

@@ -131,34 +176,64 @@ def __array_finalize__(self, obj):
131176
def azimuthal(self):
132177
return self.view(self._azimuthal_type)
133178

134-
@property
135-
def x(self):
136-
return vector.compute.planar.x.dispatch(numpy, self)
137179

138-
@property
139-
def y(self):
140-
return vector.compute.planar.y.dispatch(numpy, self)
180+
class PlanarVectorNumpy(vector.geometry.PlanarVector, PlanarNumpy):
181+
def __getitem__(self, where):
182+
return getitem(self, where, vector.backends.object_.PlanarVectorObject)
141183

142-
@property
143-
def rho(self):
144-
return vector.compute.planar.rho.dispatch(numpy, self)
145184

146-
@property
147-
def phi(self):
148-
return vector.compute.planar.phi.dispatch(numpy, self)
185+
class SpatialNumpy(numpy.ndarray, vector.methods.Spatial):
186+
lib = numpy
149187

150-
@property
151-
def rho2(self):
152-
return vector.compute.planar.rho2.dispatch(numpy, self)
188+
def __new__(cls, *args, **kwargs):
189+
return numpy.array(*args, **kwargs).view(cls)
153190

154191

155-
class SpatialVectorNumpy(
156-
numpy.ndarray, vector.geometry.Spatial, vector.geometry.Vector
157-
):
158-
pass
192+
class SpatialVectorNumpy(vector.geometry.SpatialVector, SpatialNumpy):
193+
def __getitem__(self, where):
194+
return getitem(self, where, vector.backends.object_.SpatialVectorObject)
159195

160196

161-
class LorentzVectorNumpy(
162-
numpy.ndarray, vector.geometry.Lorentz, vector.geometry.Vector
163-
):
164-
pass
197+
class LorentzNumpy(numpy.ndarray, vector.methods.Lorentz):
198+
lib = numpy
199+
200+
def __new__(cls, *args, **kwargs):
201+
return numpy.array(*args, **kwargs).view(cls)
202+
203+
def __array_finalize__(self, obj):
204+
raise NotImplementedError
205+
206+
207+
class LorentzVectorNumpy(vector.geometry.LorentzVector, LorentzNumpy):
208+
def __getitem__(self, where):
209+
return getitem(self, where, vector.backends.object_.LorentzVectorObject)
210+
211+
212+
class TransformNumpy:
213+
lib = numpy
214+
215+
216+
class Transform2DNumpy(numpy.ndarray, TransformNumpy, vector.methods.Transform2D):
217+
def __new__(cls, *args, **kwargs):
218+
return numpy.array(*args, **kwargs).view(cls)
219+
220+
def __array_finalize__(self, obj):
221+
if self.dtype.names != ("xx", "xy", "yx", "yy"):
222+
raise TypeError(
223+
f"{type(self).__name__} must have a structured dtype with fields "
224+
'("xx", "xy", "yx", "yy")'
225+
)
226+
227+
def __getitem__(self, where):
228+
return getitem(self, where, vector.backends.object_.Transform2DObject)
229+
230+
@property
231+
def elements(self):
232+
return tuple(self[x] for x in ("xx", "xy", "yx", "yy"))
233+
234+
def apply(self, v):
235+
x, y = vector.methods.Transform2D.apply(self, v)
236+
out = numpy.empty(x.shape, dtype=[("x", x.dtype), ("y", y.dtype)])
237+
out["x"] = x
238+
out["y"] = y
239+
return out.view(PlanarVectorNumpy)

0 commit comments

Comments
 (0)