Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,43 @@ def to_python(self, obj_id: int):
return ctypes.cast(obj_id, ctypes.py_object).value


class Float32sr(typeclass):
    """A 32-bit floating-point type that uses stochastic rounding (SR).

    Storage and numpy interoperability are identical to ``numpy.float32``;
    the type additionally carries the ``stochastically_rounded`` marker
    attribute, which the frontend checks to decide whether operations should
    be lowered to the ``dace::float32sr`` C++ runtime type instead of a
    plain ``float``.
    """

    def __init__(self):
        self.type = numpy.float32  # underlying storage / numpy type
        self.bytes = 4
        self.dtype = self
        self.typename = "float"
        # Marker checked (via hasattr) by operator replacements to detect
        # stochastically rounded inputs.
        self.stochastically_rounded = True

    def to_json(self):
        return 'float32sr'

    @staticmethod
    def from_json(json_obj, context=None):
        from dace.symbolic import pystr_to_symbolic  # must be included!
        return float32sr()

    @property
    def ctype(self):
        # C++ runtime type implementing stochastic rounding.
        return "dace::float32sr"

    @property
    def ctype_unaligned(self):
        return self.ctype

    def as_ctypes(self):
        """ Returns the ctypes version of the typeclass. """
        return _FFI_CTYPES[self.type]

    def as_numpy_dtype(self):
        # SR is a code-generation property; numpy sees a plain float32.
        return numpy.dtype(self.type)

    @property
    def base_type(self):
        return self

class compiletime:
"""
Data descriptor type hint signalling that argument evaluation is
Expand Down Expand Up @@ -1183,6 +1220,7 @@ def isconstant(var):
complex128 = typeclass(numpy.complex128)
string = stringtype()
MPI_Request = opaque('MPI_Request')
float32sr = Float32sr()


@undefined_safe_enum
Expand Down
46 changes: 41 additions & 5 deletions dace/frontend/python/replacements/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,36 @@ def _is_op_boolean(op: str):
return False


def _handle_casting_for_stochastically_rounded_types(input_datatypes, restype, cast_types):
    """Adjust the result type and intermediate casts for stochastically
    rounded (SR) inputs.

    :param input_datatypes: Data types of the operation's input arguments.
    :param restype: Result type computed by the regular promotion rules.
    :param cast_types: Per-argument cast annotations (dotted type-name
                       strings or None). NOTE: mutated in place.
    :return: The (possibly SR-promoted) result type.
    """
    float_to_sr = {
        dace.float32: dace.float32sr,
    }

    def _resolve_cast(cast_name):
        # Cast entries are dotted type names (e.g. "dace.float32" —
        # TODO confirm against the format emitted by cast_str). Resolve
        # them with a safe attribute lookup instead of eval().
        if not cast_name:
            return None
        return getattr(dace, cast_name.rsplit('.', 1)[-1], None)

    # Drop redundant casts: an SR input whose underlying storage type
    # already matches the requested cast target needs no cast at all.
    for i, dtype in enumerate(input_datatypes):
        if hasattr(dtype, "stochastically_rounded"):
            resolved = _resolve_cast(cast_types[i])
            if resolved is not None and dtype.type == resolved.type:
                cast_types[i] = None

    # The result is stochastically rounded only if *all* inputs are.
    if all(hasattr(dtype, "stochastically_rounded") for dtype in input_datatypes):
        # Promote the result type to its SR counterpart.
        if restype in float_to_sr:
            restype = float_to_sr[restype]

        # Promote the intermediate cast types as well (in place).
        for i, dtype in enumerate(cast_types):
            if dtype in float_to_sr:
                cast_types[i] = float_to_sr[dtype]

    return restype


def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic]],
operator: str = None) -> Tuple[Union[List[dtypes.typeclass], dtypes.typeclass, str], ...]:

Expand Down Expand Up @@ -144,12 +174,16 @@ def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic
raise TypeError("Type {t} of argument {a} is not supported".format(t=type(arg), a=arg))

complex_types = {dtypes.complex64, dtypes.complex128, np.complex64, np.complex128}
float_types = {dtypes.float16, dtypes.float32, dtypes.float64, np.float16, np.float32, np.float64}
float_types = {dace.float16, dace.float32, dace.float32sr, dace.float64, np.float16, np.float32, np.float64}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, you should use the module if it exists.

signed_types = {dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, np.int8, np.int16, np.int32, np.int64}
# unsigned_types = {np.uint8, np.uint16, np.uint32, np.uint64}

coarse_types = []
for dtype in datatypes:
for dt in datatypes:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the rename?

dtype = dt
if hasattr(dt, "srtype"): # unwrap stochastically rounded vars
dtype = dt.srtype

if dtype in complex_types:
coarse_types.append(3) # complex
elif dtype in float_types:
Expand Down Expand Up @@ -336,18 +370,20 @@ def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic
else: # Operators with 3 or more arguments
restype = np_result_type(dtypes_for_result)
coarse_result_type = None
if result_type in complex_types:
if restype in complex_types:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is/was result_type, why did this change? Was this a bug?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug introduced by this refactor: c909f8b#diff-b20441227a628465d1c6ca8819915b77cf036f8fc1a8e26fbceb8a930fde9d1dR353

The original code, formerly in the replacements.py file, looked like

    else:  # Operators with 3 or more arguments
        result_type = _np_result_type(dtypes_for_result)
        coarse_result_type = None
        if result_type in complex_types:
            coarse_result_type = 3  # complex
        elif result_type in float_types:
            coarse_result_type = 2  # float
        elif result_type in signed_types:
            coarse_result_type = 1  # signed integer, bool
        else:
            coarse_result_type = 0  # unsigned integer
        for i, t in enumerate(coarse_types):
            if t != coarse_result_type:
                casting[i] = _cast_str(result_type)

I believe this variable was renamed to `restype` to avoid conflicting with the `result_type` function.

coarse_result_type = 3 # complex
elif result_type in float_types:
elif restype in float_types:
coarse_result_type = 2 # float
elif result_type in signed_types:
elif restype in signed_types:
coarse_result_type = 1 # signed integer, bool
else:
coarse_result_type = 0 # unsigned integer
for i, t in enumerate(coarse_types):
if t != coarse_result_type:
casting[i] = cast_str(restype)

restype = _handle_casting_for_stochastically_rounded_types(datatypes, restype, casting)

return restype, casting


Expand Down
10 changes: 8 additions & 2 deletions dace/libraries/blas/nodes/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
(desc_x, stride_x), (desc_y, stride_y), desc_res, sz = node.validate(parent_sdfg, parent_state)
dtype = desc_x.dtype.base_type
veclen = desc_x.dtype.veclen
cast = "(float *)" if dtype == dace.float32sr else ""

try:
func, _, _ = blas_helpers.cublas_type_metadata(dtype)
Expand All @@ -82,7 +83,8 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
n = n or node.n or sz
if veclen != 1:
n /= veclen
code = f"_result = cblas_{func}({n}, _x, {stride_x}, _y, {stride_y});"

code = f"_result = cblas_{func}({n}, {cast} _x, {stride_x}, {cast} _y, {stride_y});"
# The return type is scalar in cblas_?dot signature
tasklet = dace.sdfg.nodes.Tasklet(node.name,
node.in_connectors, {'_result': dtype},
Expand Down Expand Up @@ -204,7 +206,11 @@ def validate(self, sdfg, state):
if desc_x.dtype != desc_y.dtype:
raise TypeError(f"Data types of input operands must be equal: {desc_x.dtype}, {desc_y.dtype}")
if desc_x.dtype.base_type != desc_res.dtype.base_type:
raise TypeError(f"Data types of input and output must be equal: {desc_x.dtype}, {desc_res.dtype}")
input_types = (desc_x.dtype.base_type, desc_res.dtype.base_type)
if dace.float32sr in input_types and dace.float32sr in input_types:
pass # ignore mismatch if it is stochastically rounded
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not an expert on SR types, so my question here is: can we safely do that, or rather - why can we safely do that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have made a typo here -- thank you for catching that. Line 210 should be:
if dace.float32sr in input_types and dace.float32 in input_types:. I will also rename input_types to arg_types and add more comments

What this does and why it's needed: when using stochastic rounding, a legitimate (i.e. not a bug) mismatch between the input and output arguments may arise where one argument is a float32sr and the other is a float32 (round-to-nearest). The underlying data type is the same, so this should not cause the validation to fail. This is a check for that edge case.

else:
raise TypeError(f"Data types of input and output must be equal: {desc_x.dtype}, {desc_res.dtype}")

# Squeeze input memlets
squeezed1 = copy.deepcopy(in_memlets[0].subset)
Expand Down
4 changes: 3 additions & 1 deletion dace/libraries/blas/nodes/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def expansion(node, state, sdfg):
node.validate(sdfg, state)
(_, adesc, _, _, _, _), (_, bdesc, _, _, _, _), _ = _get_matmul_operands(node, state, sdfg)
dtype = adesc.dtype.base_type

func = to_blastype(dtype.type).lower() + 'gemm'
alpha = f'{dtype.ctype}({node.alpha})'
beta = f'{dtype.ctype}({node.beta})'
Expand All @@ -178,6 +179,7 @@ def expansion(node, state, sdfg):
check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)

opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, dtype.ctype, func)
opt['cast'] = "(float *)" if dtype == dace.float32sr else ""

# Adaptations for BLAS API
opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
Expand All @@ -193,7 +195,7 @@ def expansion(node, state, sdfg):
opt['beta'] = '&__beta'

code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
"{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
"{M}, {N}, {K}, {alpha},{cast} {x}, {lda}, {cast} {y}, {ldb}, {beta}, "
"_c, {ldc});").format_map(opt)

tasklet = dace.sdfg.nodes.Tasklet(
Expand Down
4 changes: 3 additions & 1 deletion dace/libraries/blas/nodes/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
name_out="_y")
dtype_a = outer_array_a.dtype.type
dtype = outer_array_x.dtype.base_type
cast = "(float *)" if dtype == dace.float32sr else ""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that C++ can handle those casts if necessary


veclen = outer_array_x.dtype.veclen
alpha = f'{dtype.ctype}({node.alpha})'
beta = f'{dtype.ctype}({node.beta})'
Expand Down Expand Up @@ -280,7 +282,7 @@ def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
alpha = '&__alpha'
beta = '&__beta'

code += f"""cblas_{func}({layout}, {trans}, {m}, {n}, {alpha}, _A, {lda},
code += f"""cblas_{func}({layout}, {trans}, {m}, {n}, {alpha}, {cast} _A, {lda},
_x, {strides_x[0]}, {beta}, _y, {strides_y[0]});"""

tasklet = dace.sdfg.nodes.Tasklet(node.name,
Expand Down
3 changes: 2 additions & 1 deletion dace/libraries/lapack/nodes/potrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ class ExpandPotrfOpenBLAS(ExpandTransformation):
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
(desc_x, stride_x, rows_x, cols_x), desc_result = node.validate(parent_sdfg, parent_state)
dtype = desc_x.dtype.base_type
cast = "(float *)" if dtype == dace.float32sr else ""
lapack_dtype = blas_helpers.to_blastype(dtype.type).lower()
if desc_x.dtype.veclen > 1:
raise (NotImplementedError)

n = n or node.n
uplo = "'L'" if node.lower else "'U'"
code = f"_res = LAPACKE_{lapack_dtype}potrf(LAPACK_ROW_MAJOR, {uplo}, {rows_x}, _xin, {stride_x});"
code = f"_res = LAPACKE_{lapack_dtype}potrf(LAPACK_ROW_MAJOR, {uplo}, {rows_x}, {cast} _xin, {stride_x});"
tasklet = dace.sdfg.nodes.Tasklet(node.name,
node.in_connectors,
node.out_connectors,
Expand Down
3 changes: 2 additions & 1 deletion dace/libraries/linalg/nodes/cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def _make_sdfg(node, parent_state, parent_sdfg, implementation):

inp_desc, inp_shape, out_desc, out_shape = node.validate(parent_sdfg, parent_state)
dtype = inp_desc.dtype
cast = "(float *)" if dtype == dace.float32sr else ""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

casting a scalar to a pointer!!!


sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

Expand All @@ -35,7 +36,7 @@ def _make_sdfg(node, parent_state, parent_sdfg, implementation):
_, me, mx = state.add_mapped_tasklet('_uzero_',
dict(__i="0:%s" % out_shape[0], __j="0:%s" % out_shape[1]),
dict(_inp=Memlet.simple('_b', '__i, __j')),
'_out = (__i < __j) ? 0 : _inp;',
f'_out = (__i < __j) ? {cast}(0) : _inp;',
dict(_out=Memlet.simple('_b', '__i, __j')),
language=dace.dtypes.Language.CPP,
external_edges=True)
Expand Down
1 change: 1 addition & 0 deletions dace/runtime/include/dace/dace.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "perf/reporting.h"
#include "comm.h"
#include "serialization.h"
#include "stocastic_rounding.h"

#if defined(__CUDACC__) || defined(__HIPCC__)
#include "cuda/cudacommon.cuh"
Expand Down
60 changes: 60 additions & 0 deletions dace/runtime/include/dace/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#define __DACE_REDUCTION_H

#include <cstdint>
#include <dace/stocastic_rounding.h>

#include "types.h"
#include "vector.h"
Expand Down Expand Up @@ -121,6 +122,40 @@ namespace dace {
}
};

// Write-conflict resolution for stochastically rounded 32-bit floats.
// Specializes the generic reducer so that atomic updates reinterpret the
// value as a 32-bit int for CAS, while the reduction itself is computed
// through the SR type (so rounding happens stochastically).
template <>
struct wcr_custom<dace::float32sr> {
// Atomically applies `wcr` at *ptr and returns the previous value.
// GPU path: classic compare-and-swap loop on the int bit-pattern;
// CPU path: OpenMP critical section.
template <typename WCR>
static DACE_HDFI dace::float32sr reduce_atomic(WCR wcr, dace::float32sr *ptr, const dace::float32sr& value) {
#ifdef DACE_USE_GPU_ATOMICS
// Stochastic rounding version of atomic float reduction
int *iptr = reinterpret_cast<int *>(ptr);
int old = *iptr, assumed;
do {
assumed = old;
// Reinterpret the observed bits as float, reduce through the SR type,
// then write back the resulting float bits if *ptr is unchanged.
float old_val = __int_as_float(assumed);
float new_val = static_cast<float>(wcr(static_cast<dace::float32sr>(old_val), value));
old = atomicCAS(iptr, assumed, __float_as_int(new_val));
} while (assumed != old);
return static_cast<dace::float32sr>(__int_as_float(old));
#else
dace::float32sr old;
#pragma omp critical
{
old = *ptr;
*ptr = wcr(old, value);
}
return old;
#endif
}

// Non-atomic variant: plain read-modify-write, returns the previous value.
template <typename WCR>
static DACE_HDFI dace::float32sr reduce(WCR wcr, dace::float32sr *ptr, const dace::float32sr& value) {
dace::float32sr old = *ptr;
*ptr = wcr(old, value);
return old;
}
};

template <>
struct wcr_custom<double> {
template <typename WCR>
Expand Down Expand Up @@ -313,6 +348,31 @@ namespace dace {
DACE_HDFI float operator()(const float &a, const float &b) const { return ::max(a, b); }
};


// Min-reduction for stochastically rounded floats. Atomics fall back to the
// generic CAS loop in wcr_custom<dace::float32sr>, since there is no native
// floating-point atomicMin.
template <>
struct _wcr_fixed<ReductionType::Min, dace::float32sr> {

    static DACE_HDFI dace::float32sr reduce_atomic(dace::float32sr *ptr, const dace::float32sr& value) {
        return wcr_custom<dace::float32sr>::reduce_atomic(
            _wcr_fixed<ReductionType::Min, dace::float32sr>(), ptr, value);
    }

    // NOTE(review): relies on ::min resolving for dace::float32sr operands;
    // some toolchains (e.g. AMD HIP) may require explicit casts to float
    // first — confirm on all targeted backends.
    DACE_HDFI dace::float32sr operator()(const dace::float32sr &a, const dace::float32sr &b) const { return ::min(a, b); }
};

// Max-reduction for stochastically rounded floats. Atomics fall back to the
// generic CAS loop in wcr_custom<dace::float32sr>, since there is no native
// floating-point atomicMax.
template <>
struct _wcr_fixed<ReductionType::Max, dace::float32sr> {

    static DACE_HDFI dace::float32sr reduce_atomic(dace::float32sr *ptr, const dace::float32sr& value) {
        return wcr_custom<dace::float32sr>::reduce_atomic(
            _wcr_fixed<ReductionType::Max, dace::float32sr>(), ptr, value);
    }

    // NOTE(review): relies on ::max resolving for dace::float32sr operands;
    // some toolchains (e.g. AMD HIP) may require explicit casts to float
    // first — confirm on all targeted backends.
    DACE_HDFI dace::float32sr operator()(const dace::float32sr &a, const dace::float32sr &b) const { return ::max(a, b); }
};


template <>
struct _wcr_fixed<ReductionType::Min, double> {

Expand Down
Loading