Skip to content

Commit d166d81

Browse files
committed
second step
1 parent 1bfd39f commit d166d81

File tree

4 files changed

+53
-16
lines changed

4 files changed

+53
-16
lines changed

onnx_array_api/npx/npx_functions.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import Any, Optional, Tuple, Union
22

33
import numpy as np
4-
from onnx import FunctionProto, ModelProto, NodeProto
4+
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
5+
from onnx.helper import np_dtype_to_tensor_dtype
56
from onnx.numpy_helper import from_array
67

78
from .npx_constants import FUNCTION_DOMAIN
@@ -188,7 +189,16 @@ def astype(
188189
raise TypeError(
189190
f"dtype is an attribute, it cannot be a Variable of type {type(dtype)}."
190191
)
191-
return var(a, op="Cast", to=dtype)
192+
try:
193+
to = np_dtype_to_tensor_dtype(dtype)
194+
except KeyError:
195+
if dtype is int:
196+
to = TensorProto.INT64
197+
elif dtype is float:
198+
to = TensorProto.float64
199+
else:
200+
raise ValueError(f"Unable to guess tensor type from {dtype}.")
201+
return var(a, op="Cast", to=to)
192202

193203

194204
@npxapi_inline

onnx_array_api/npx/npx_graph_builder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,13 @@ def make_node(
235235
f"Cannot create a node Identity for {len(inputs)} input(s) and "
236236
f"{len(outputs)} output(s)."
237237
)
238-
node = make_node(op, inputs, outputs, domain=domain, **new_kwargs)
238+
try:
239+
node = make_node(op, inputs, outputs, domain=domain, **new_kwargs)
240+
except TypeError as e:
241+
raise TypeError(
242+
f"Unable to create node {op!r}, with inputs={inputs}, "
243+
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
244+
) from e
239245
for p in protos:
240246
node.attribute.append(p)
241247
if attribute_protos is not None:

onnx_array_api/npx/npx_jit_eager.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -142,30 +142,50 @@ def to_jit(self, *values, **kwargs):
142142
types and the expected number of dimensions.
143143
"""
144144
annotations = self.f.__annotations__
145-
annot_values = list(annotations.values())
146-
constraints = {
147-
f"x{i}": v.tensor_type_dims
148-
for i, v in enumerate(values)
149-
if isinstance(v, (EagerTensor, JitTensor))
150-
and (i >= len(annot_values) or issubclass(annot_values[i], TensorType))
151-
}
145+
if len(annotations) > 0:
146+
names = list(annotations.keys())
147+
annot_values = list(annotations.values())
148+
constraints = {}
149+
new_kwargs = {}
150+
for i, (v, iname) in enumerate(zip(values, names)):
151+
if isinstance(v, (EagerTensor, JitTensor)) and (
152+
i >= len(annot_values) or issubclass(annot_values[i], TensorType)
153+
):
154+
constraints[iname] = v.tensor_type_dims
155+
else:
156+
new_kwargs[iname] = v
157+
else:
158+
names = [f"x{i}" for i in range(len(values))]
159+
new_kwargs = {}
160+
constraints = {
161+
iname: v.tensor_type_dims
162+
for i, (v, iname) in enumerate(zip(values, names))
163+
if isinstance(v, (EagerTensor, JitTensor))
164+
}
152165

153166
if self.output_types is not None:
154167
constraints.update(self.output_types)
155168

156-
inputs = [Input(f"x{i}") for i in range(len(values)) if f"x{i}" in constraints]
157-
if len(inputs) < len(values):
169+
inputs = [
170+
Input(iname) for iname, v in zip(names, values) if iname in constraints
171+
]
172+
names = [i.name for i in inputs]
173+
if len(new_kwargs) > 0:
158174
# An attribute is not named in the numpy API
159175
# but is the ONNX definition.
160-
raise NotImplementedError()
176+
if len(kwargs) == 0:
177+
kwargs = new_kwargs
178+
else:
179+
kwargs = kwargs.copy()
180+
kwargs.update(kwargs)
181+
161182
var = self.f(*inputs, **kwargs)
162183

163184
onx = var.to_onnx(
164185
constraints=constraints,
165186
target_opsets=self.target_opsets,
166187
ir_version=self.ir_version,
167188
)
168-
names = [f"x{i}" for i in range(len(values))]
169189
exe = self.tensor_class.create_function(names, onx)
170190
return onx, exe
171191

@@ -223,7 +243,7 @@ def jit_call(self, *values, **kwargs):
223243
raise RuntimeError(
224244
f"Unable to run function for key={key!r}, "
225245
f"types={[type(x) for x in values]}, "
226-
f"onnx={self.onxs[key]}."
246+
f"kwargs={kwargs}, onnx={self.onxs[key]}."
227247
) from e
228248
return res
229249

onnx_array_api/npx/npx_numpy_tensors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def run(self, *inputs: List["NumpyTensor"]) -> List["NumpyTensor"]:
3737
"""
3838
if len(inputs) != len(self.input_names):
3939
raise ValueError(
40-
f"Expected {len(self.input_names)} inputs but got " f"len(inputs)."
40+
f"Expected {len(self.input_names)} inputs but got {len(inputs)}, "
41+
f"self.input_names={self.input_names}, inputs={inputs}."
4142
)
4243
feeds = {}
4344
for name, inp in zip(self.input_names, inputs):

0 commit comments

Comments
 (0)