This repository was archived by the owner on Dec 14, 2025. It is now read-only.
Changes from all commits
Commits (61)
8ef30b8
Replacement of GCNConv to a placeholder Op, shape inference
lamyiowce Apr 18, 2022
3289036
Proper file structure, ONNX placeholder op expands into an extension …
lamyiowce May 13, 2022
c92d721
Produces an SDFG with dummy computation, restores module data, shape …
lamyiowce May 16, 2022
50a81c7
fix
lamyiowce May 16, 2022
b124f0e
move module_id to attribute
lamyiowce May 16, 2022
ddddcb2
missing import
lamyiowce May 16, 2022
c1273fa
Formatting
lamyiowce May 16, 2022
d29c269
nicer handling of class name
lamyiowce May 16, 2022
1c90b54
Shape inference in registration and attempts at proper parameter hand…
lamyiowce May 23, 2022
b5497e7
Add parameters as input nodes in the graph.
lamyiowce Jun 5, 2022
cbbc3c5
Add parameters as input nodes in the graph.
lamyiowce Jun 5, 2022
ff7ef22
Weights are correctly added to replaced modules. Fix an annoying bug …
lamyiowce Jun 6, 2022
3f0be8a
implementation of GCNConv and a test for it
lamyiowce Jun 12, 2022
cd319d0
Bug fixes in adding weights as inputs
lamyiowce Jun 22, 2022
7236060
Replacement of GCNConv to a placeholder Op, shape inference
lamyiowce Apr 18, 2022
f139f12
Proper file structure, ONNX placeholder op expands into an extension …
lamyiowce May 13, 2022
2db639d
Produces an SDFG with dummy computation, restores module data, shape …
lamyiowce May 16, 2022
eb9c7cd
fix
lamyiowce May 16, 2022
1ab4e3c
move module_id to attribute
lamyiowce May 16, 2022
ab10e61
missing import
lamyiowce May 16, 2022
f8f621e
Formatting
lamyiowce May 16, 2022
7f24524
nicer handling of class name
lamyiowce May 16, 2022
0990840
Shape inference in registration and attempts at proper parameter hand…
lamyiowce May 23, 2022
d7815f5
Add parameters as input nodes in the graph.
lamyiowce Jun 5, 2022
3ed8994
Add parameters as input nodes in the graph.
lamyiowce Jun 5, 2022
16cd590
Weights are correctly added to replaced modules. Fix an annoying bug …
lamyiowce Jun 6, 2022
ecb4ee5
implementation of GCNConv and a test for it
lamyiowce Jun 12, 2022
214bc2b
Bug fixes in adding weights as inputs
lamyiowce Jun 22, 2022
5470395
Add support for sparse matrix input.
lamyiowce Jun 25, 2022
e67aa03
Merge remote-tracking branch 'origin/master'
lamyiowce Jun 25, 2022
8ff27ea
Add test for multiple layers for GCNConv
lamyiowce Jul 7, 2022
9bf3aa1
Add benchmarking script
lamyiowce Jul 18, 2022
376ac61
cleaner benchmark script, fix in dtype in gcnconv implementation
lamyiowce Jul 19, 2022
5f4056c
Pre-normalize the inputs
lamyiowce Jul 20, 2022
b1db0df
add table and fix test
lamyiowce Jul 20, 2022
793f7d7
Buggy GAT (but compiles!)
lamyiowce Jul 20, 2022
2ba661f
GAT benchmark
lamyiowce Jul 20, 2022
645dffc
Add threadblock dynamic map scheduling and get rid of unnecessary Dto…
lamyiowce Jul 22, 2022
02e4c21
Add named inputs/outputs in replacement registration, minor code changes
lamyiowce Aug 13, 2022
f76914f
Merging upstream branch
lamyiowce Aug 14, 2022
d66d0c1
Code cleanup and fixes after merge.
lamyiowce Aug 14, 2022
3a02075
Add pytorch_geometric add setup.py
orausch Aug 15, 2022
8f57b42
Add script to run benchmarks, write performance results to a file, fi…
lamyiowce Nov 15, 2022
cbe69d8
Merge branch 'spcl:master' into master
lamyiowce Nov 15, 2022
896af0c
Final changes: improve benchmarking output, final script for ault.
lamyiowce Jan 23, 2023
36e7e00
Remove a print.
lamyiowce Jan 23, 2023
4e46d1b
Code cleanup.
lamyiowce Jan 23, 2023
bd945c6
Type annotation for shape fn.
lamyiowce Jan 23, 2023
bebdb61
Use output dtype from replacement info, use dace dtypes instead of st…
lamyiowce Jan 23, 2023
94adde9
Revert formatting changes to symbolic_shape_infer.py
lamyiowce Jan 24, 2023
f121ed0
Oopsie, revert breaking change in symbolic_shape_infer.py
lamyiowce Jan 24, 2023
cf8dbb8
Revert all changes in symbolic_shape_infer.py. Now the SymbolicShapeI…
lamyiowce Jan 24, 2023
30c8f7a
Fix wrong dependency in setup.py
lamyiowce Jan 24, 2023
4f30c49
Clean up implementations and make tests smaller so they run quicker.
lamyiowce Jan 24, 2023
23e7310
Remove the slurm script
lamyiowce Jan 24, 2023
374e3a1
Slightly clean up shape_inference.py
lamyiowce Jan 24, 2023
c2e30b1
Fix typo in comment
lamyiowce Jan 26, 2023
d8cb794
Handle thread block dynamic scheduling in a more elegant way.
lamyiowce Feb 10, 2023
49331f6
Fix typos in shape inference!
lamyiowce Feb 10, 2023
d0bf864
Reformatting
lamyiowce Feb 12, 2023
dcfdc08
Merge branch 'spcl:master' into master
lamyiowce Apr 3, 2023
21 changes: 20 additions & 1 deletion daceml/onnx/converters.py
@@ -1,8 +1,10 @@
import re
import logging
import re
from typing import Union

import dace
import onnx
import torch
from dace import dtypes as dt
from dace.dtypes import typeclass
from onnx.numpy_helper import to_array
@@ -209,3 +211,20 @@ def clean_onnx_name(name: str) -> str:
        name = f"ONNX_{name}"
    return name.replace(".", "DOT").replace(":", "COLON").replace(
        "/", "SLASH").replace("-", "DASH")


TYPECLASS_TO_TORCH_DTYPE = {
    dace.bool_: torch.bool,
    dace.int8: torch.int8,
    dace.int16: torch.int16,
    dace.int32: torch.int32,
    dace.int64: torch.int64,
    dace.uint8: torch.uint8,
    dace.float16: torch.float16,
    dace.float32: torch.float32,
    dace.float64: torch.float64,
    dace.complex64: torch.complex64,
    dace.complex128: torch.complex128,
}

TORCH_DTYPE_TO_TYPECLASS = {v: k for k, v in TYPECLASS_TO_TORCH_DTYPE.items()}
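
These two tables map dace typeclasses to torch dtypes and back. A minimal usage sketch (not part of the diff), assuming the names are importable from daceml.onnx.converters as shown above:

import dace
import torch

from daceml.onnx.converters import (TYPECLASS_TO_TORCH_DTYPE,
                                    TORCH_DTYPE_TO_TYPECLASS)

# Look up the torch dtype matching a dace typeclass, and go back again.
assert TYPECLASS_TO_TORCH_DTYPE[dace.float32] is torch.float32
assert TORCH_DTYPE_TO_TYPECLASS[torch.int64] == dace.int64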
3 changes: 3 additions & 0 deletions daceml/onnx/nodes/__init__.py
@@ -1,3 +1,6 @@
from .onnx_op import *
from .replacement_entries import *
from .replacement import *

# we don't want to export ONNXOp
del globals()["ONNXOp"]
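
With these star-imports in place, the public helpers defined in replacement.py should become importable from the package directly. A small sketch (an assumption about the re-exports, not part of the diff); the module key is a placeholder and depends on what register_replacement was called with:

from daceml.onnx.nodes import is_replaceable, get_replaced_onnx_op

# 'mypackage.MyLinear' is a hypothetical name used only for illustration.
if is_replaceable('mypackage.MyLinear'):
    op_cls = get_replaced_onnx_op('mypackage.MyLinear')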
287 changes: 287 additions & 0 deletions daceml/onnx/nodes/replacement.py
@@ -0,0 +1,287 @@
import logging
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, Tuple, Type, Mapping

import dace
import torch
from dace import SDFG, nodes
from dace.properties import Property
from dace.transformation.transformation import ExpandTransformation
from onnx.onnx_pb import NodeProto

from daceml.onnx.converters import clean_onnx_name, TORCH_DTYPE_TO_TYPECLASS, typeclass_to_onnx_str, \
    TYPECLASS_TO_TORCH_DTYPE
from daceml.onnx.forward_implementation_abc import ONNXForward
from daceml.onnx.nodes.node_codegen import expand_node
from daceml.onnx.nodes.onnx_op import (ONNXOp, _get_attr_docstring,
_get_connector_docstring,
_get_typecons_docstring)
from daceml.onnx.schema import ONNXParameter, ONNXParameterType, ONNXSchema, ONNXTypeConstraint
from daceml.onnx.shape_inference.symbolic_shape_infer import SymbolicShapeInference

log = logging.getLogger(__name__)

ShapeFnType = Callable[..., Tuple[int, ...]]


@dataclass
class ReplacementInfo:
    module_name: str
    onnx_op: Type[nodes.Node]
    infer_shape: Callable[[SymbolicShapeInference, NodeProto], None]
    shape_fn_from_module: Callable[[torch.nn.Module], ShapeFnType]
    output_dtype: torch.dtype


MODULES_TO_REPLACE: Dict[str, ReplacementInfo] = {}


def is_replaceable(name: str) -> bool:
    return name in MODULES_TO_REPLACE


def get_replaced_onnx_op(name: str) -> Type[nodes.Node]:
    if name not in MODULES_TO_REPLACE:
        raise ValueError(f'No replacement module for {name}.')
    return MODULES_TO_REPLACE[name].onnx_op


def make_schema_dict(name, inputs: Mapping[str, dace.typeclass],
                     outputs: Mapping[str, dace.typeclass]):
    intersection = [name for name in inputs if name in outputs]
    assert len(
        intersection
    ) == 0, f"Same keys for inputs and outputs not allowed: {intersection}"

    schema_dict = {
        'name': name,
        'attributes': {},
        'doc': f'Placeholder for {name}',
        'domain': '',
        'since_version': 1,
        'type': 'ONNXSchema'
    }

    def make_type_info_helper(type_mapping: Mapping[str, dace.typeclass]):
        data_type_list = []
        type_constraints = {}
        for i, (name, t) in enumerate(type_mapping.items()):
            # For some reason dace.float32 gets converted to string as 'float',
            # not 'float32' which is not understood by ONNX.
            if t is dace.float32:
                t = 'float32'
            else:
                t = typeclass_to_onnx_str(t)
            type_name = f'{name}_T'
            entry = {
                'description': '',
                'homogeneous': True,
                'name': f'{name}',
                'param_type': 'Single',
                'type': 'ONNXParameter',
                'type_str': type_name
            }
            data_type_list.append(entry)

            type_constraints[type_name] = {
                'type': 'ONNXTypeConstraint',
                'type_str': type_name,
                'types': [t]
            }
        return data_type_list, type_constraints

    inputs_info, inputs_type_constraints = make_type_info_helper(inputs)
    outputs_info, outputs_type_constraints = make_type_info_helper(outputs)

    schema_dict.update({
        'inputs': inputs_info,
        'outputs': outputs_info,
        'type_constraints': {
            **inputs_type_constraints,
            **outputs_type_constraints
        },
    })
    return schema_dict


def onnx_type_info_from_torch_params(
        params: Iterable[Tuple[str, torch.nn.Parameter]]):
    onnx_params = []
    onnx_type_constraints = {}
    for name, p in params:
        name = clean_onnx_name(name)
        type_name = name + '_T'
        onnx_params.append(
            ONNXParameter.from_json({
                'description': '',
                'homogeneous': True,
                'name': name,
                'param_type': 'Single',
                'type': 'ONNXParameter',
                'type_str': type_name
            }))
        onnx_type_constraints[type_name] = ONNXTypeConstraint.from_json({
            'type':
            'ONNXTypeConstraint',
            'type_str':
            type_name,
            'types': [TORCH_DTYPE_TO_TYPECLASS[p.dtype].to_string()],
        })
    return onnx_params, onnx_type_constraints


# Generating an ONNX Library node.
def generate_onnx_op_placeholder(schema):
    attrs = {}

    def __init__(self,
                 name,
                 module,
                 prefix,
                 *args,
                 location=None,
                 **op_attributes):
        # Add information about module parameters to the schema.
        onnx_params, onnx_type_constraints = onnx_type_info_from_torch_params(
            module.named_parameters())
        self.schema = deepcopy(self.schema)
        self.schema.inputs += onnx_params
        self.schema.type_constraints.update(onnx_type_constraints)
        # TODO: Get input/output spec from module?

        super(ONNXOp, self).__init__(
            name,
            location=location,
            # add required parameters as in/out connectors, without types for now
            inputs={
                inp.name
                for inp in self.schema.inputs
                if inp.param_type == ONNXParameterType.Single
            },
            outputs={
                out.name
                for out in self.schema.outputs
                if out.param_type == ONNXParameterType.Single
            })

        self.backward_implementation = None
        self.module = module
        self.prefix = prefix

        if len(args) > 0:
            raise TypeError(
                f"__init__() takes 2 positional arguments but {2 + len(args)} were given"
            )

        if len(op_attributes) > 0:
            raise TypeError(
                f"__init__() takes no keyword arguments but following were given: {op_attributes}"
            )

    # TODO: the docstrings for params are missing, but are they needed?
    input_connector_docstrings = "\n".join(
        _get_connector_docstring(param) for param in schema.inputs)
    output_connector_docstrings = "\n".join(
        _get_connector_docstring(param) for param in schema.outputs)

    cls_name = schema.name

    # the first line of the init docstring contains the signature of the method. This will be picked up by sphinx and
    # means that the generated sphinx docs have a proper signature, and not just *args, **kwargs.
    init_docstring = "__init__(name, *, {})\n".format(
        ", ".join(attr.name if attr.required else attr.name + "=" +
                  repr(attr.default_value)
                  for _, attr in schema.attributes.items()))
    init_docstring += ":param name: the name of the node.\n" + "\n".join(
        _get_attr_docstring(attr) for _, attr in schema.attributes.items())

    docstring = "\n" + schema.doc
    type_docstrings = "\n".join(
        _get_typecons_docstring(cons)
        for _, cons in schema.type_constraints.items())
    docstring += "\n\n"
    docstring += ":Node Inputs:" + input_connector_docstrings
    docstring += "\n\n"
    docstring += ":Node Outputs:" + output_connector_docstrings
    docstring += "\n\n"
    docstring += ":Type Constraints:" + type_docstrings

    # TODO: Check if the documentation makes any sense. Maybe copy from GCNConv or maybe not
    attrs['__doc__'] = docstring + "\n"
    attrs['schema'] = schema
    attrs['__init__'] = __init__
    attrs['module'] = Property(dtype=torch.nn.Module,
                               desc='Replaced module',
                               allow_none=False)
    attrs['prefix'] = Property(dtype=str,
                               desc='Prefix for the module.',
                               allow_none=False)

    cls = type(cls_name, (ONNXOp, ), attrs)
    cls = dace.library.node(cls)
    cls.__init__.__doc__ = "\n" + init_docstring

    for impl, args in ONNXForward.extensions().items():
        if "op" in args and args["op"] == schema.name:

            class Expansion(ExpandTransformation):
                environments = []
                forward_impl: ONNXForward = impl

                @classmethod
                def expansion(cls, node, state, sdfg, **kwargs):
                    # validate
                    node.validate(sdfg, state)

                    if cls.forward_impl.forward_can_be_applied(
                            node, state, sdfg):
                        result = cls.forward_impl.forward(
                            node, state, sdfg, **kwargs)
                        if hasattr(cls.forward_impl, "environments"):
                            cls.environments.extend(
                                cls.forward_impl.environments)
                        return result
                    else:
                        log.warning(
                            'No expansion for library node "{}". '
                            'Reason: forward_can_be_applied returned False'.
                            format(node.label))
                        result = expand_node(node, state, sdfg)
                        if not isinstance(result, SDFG):
                            # When we return an SDFG the environments will be determined recursively by codegen.
                            cls.environments = map(
                                dace.library.get_environment,
                                result.environments)
                        return result

            implementation_name = args["name"]
            cls.register_implementation(implementation_name, Expansion)

    globals()[cls_name] = cls

    return cls


# Registration of replacement.
def register_replacement(
        module_name: str, inputs: Mapping[str, dace.typeclass],
        outputs: Mapping[str, dace.typeclass],
        shape_infer: Callable[[SymbolicShapeInference, 'NodeProto'], None],
        shape_fn_from_module: Callable[[torch.nn.Module], ShapeFnType]):
    if len(outputs) > 1:
        raise NotImplementedError(
            "Replacing nodes with more than 1 output is not supported.")

    output_dtype = next(iter(outputs.values()))

    schema_dict = make_schema_dict(module_name, inputs, outputs)
    schema = ONNXSchema.from_json(schema_dict)
    onnx_op = generate_onnx_op_placeholder(schema)
    replacement_info = ReplacementInfo(
        module_name=module_name,
        infer_shape=shape_infer,
        shape_fn_from_module=shape_fn_from_module,
        onnx_op=onnx_op,
        output_dtype=TYPECLASS_TO_TORCH_DTYPE[output_dtype])
    MODULES_TO_REPLACE[module_name] = replacement_info
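
To show how the pieces above fit together, here is a minimal registration sketch (not part of the diff). The module name, the connector names, and both shape callbacks are hypothetical; the real registrations added by this PR (e.g. for GCNConv) live in replacement_entries.py, which is imported in nodes/__init__.py above.

import dace
import torch

from daceml.onnx.nodes.replacement import register_replacement


def infer_shape(symbolic_shape_inference, node_proto):
    # Hook for ONNX symbolic shape inference; would propagate the symbolic
    # shape of the placeholder op's input to its output. Left empty here.
    pass


def shape_fn_from_module(module: torch.nn.Module):
    # Returns a callable mapping the input shape to the output shape, e.g.
    # (N, in_features) -> (N, module.out_features) for a linear-style layer.
    # 'out_features' is an attribute of the hypothetical module.
    def shape_fn(input_shape):
        return (input_shape[0], module.out_features)

    return shape_fn


register_replacement('mypackage.MyLinear',
                     inputs={'input': dace.float32},
                     outputs={'output': dace.float32},
                     shape_infer=infer_shape,
                     shape_fn_from_module=shape_fn_from_module)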