Skip to content

Commit c9576df

Browse files
creavintbennun
and authored
Add dace::float32sr type to DaCe (#2148)
This new type enables DaCe users to perform calculations with stochastic rounding in single precision. This change is validated with unit tests. --------- Co-authored-by: Tal Ben-Nun <[email protected]>
1 parent 846d791 commit c9576df

File tree

11 files changed

+938
-11
lines changed

11 files changed

+938
-11
lines changed

dace/dtypes.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,43 @@ def to_python(self, obj_id: int):
817817
return ctypes.cast(obj_id, ctypes.py_object).value
818818

819819

820+
class Float32sr(typeclass):
821+
822+
def __init__(self):
823+
self.type = numpy.float32
824+
self.bytes = 4
825+
self.dtype = self
826+
self.typename = "float"
827+
self.stochastically_rounded = True
828+
829+
def to_json(self):
830+
return 'float32sr'
831+
832+
@staticmethod
833+
def from_json(json_obj, context=None):
834+
from dace.symbolic import pystr_to_symbolic # must be included!
835+
return float32sr()
836+
837+
@property
838+
def ctype(self):
839+
return "dace::float32sr"
840+
841+
@property
842+
def ctype_unaligned(self):
843+
return self.ctype
844+
845+
def as_ctypes(self):
846+
""" Returns the ctypes version of the typeclass. """
847+
return _FFI_CTYPES[self.type]
848+
849+
def as_numpy_dtype(self):
850+
return numpy.dtype(self.type)
851+
852+
@property
853+
def base_type(self):
854+
return self
855+
856+
820857
class compiletime:
821858
"""
822859
Data descriptor type hint signalling that argument evaluation is
@@ -1175,6 +1212,7 @@ def isconstant(var):
11751212
complex128 = typeclass(numpy.complex128)
11761213
string = stringtype()
11771214
MPI_Request = opaque('MPI_Request')
1215+
float32sr = Float32sr()
11781216

11791217

11801218
@undefined_safe_enum

dace/frontend/python/replacements/operators.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,36 @@ def _is_op_boolean(op: str):
113113
return False
114114

115115

116+
def _handle_casting_for_stochastically_rounded_types(input_datatypes, restype, cast_types):
117+
float_to_sr = {
118+
dace.float32: dace.float32sr,
119+
}
120+
121+
for i, dtype in enumerate(input_datatypes):
122+
if hasattr(dtype, "stochastically_rounded"):
123+
if cast_types[i] and dtype.type == eval(cast_types[i]).type:
124+
cast_types[i] = None
125+
126+
# check if stoc rounded inputs
127+
stochastically_rounded = True
128+
for i, dtype in enumerate(input_datatypes):
129+
if not hasattr(dtype, "stochastically_rounded"):
130+
stochastically_rounded = False
131+
break
132+
133+
if stochastically_rounded:
134+
# make the result SR
135+
if restype in float_to_sr:
136+
restype = float_to_sr[restype]
137+
138+
# cast the intermediate types
139+
for i, dtype in enumerate(cast_types):
140+
if dtype in float_to_sr:
141+
cast_types[i] = float_to_sr[dtype]
142+
143+
return restype
144+
145+
116146
def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic]],
117147
operator: str = None) -> Tuple[Union[List[dtypes.typeclass], dtypes.typeclass, str], ...]:
118148

@@ -144,12 +174,16 @@ def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic
144174
raise TypeError("Type {t} of argument {a} is not supported".format(t=type(arg), a=arg))
145175

146176
complex_types = {dtypes.complex64, dtypes.complex128, np.complex64, np.complex128}
147-
float_types = {dtypes.float16, dtypes.float32, dtypes.float64, np.float16, np.float32, np.float64}
177+
float_types = {dace.float16, dace.float32, dace.float32sr, dace.float64, np.float16, np.float32, np.float64}
148178
signed_types = {dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, np.int8, np.int16, np.int32, np.int64}
149179
# unsigned_types = {np.uint8, np.uint16, np.uint32, np.uint64}
150180

151181
coarse_types = []
152-
for dtype in datatypes:
182+
for dt in datatypes:
183+
dtype = dt
184+
if hasattr(dt, "srtype"): # unwrap stochastically rounded vars
185+
dtype = dt.srtype
186+
153187
if dtype in complex_types:
154188
coarse_types.append(3) # complex
155189
elif dtype in float_types:
@@ -336,18 +370,20 @@ def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic
336370
else: # Operators with 3 or more arguments
337371
restype = np_result_type(dtypes_for_result)
338372
coarse_result_type = None
339-
if result_type in complex_types:
373+
if restype in complex_types:
340374
coarse_result_type = 3 # complex
341-
elif result_type in float_types:
375+
elif restype in float_types:
342376
coarse_result_type = 2 # float
343-
elif result_type in signed_types:
377+
elif restype in signed_types:
344378
coarse_result_type = 1 # signed integer, bool
345379
else:
346380
coarse_result_type = 0 # unsigned integer
347381
for i, t in enumerate(coarse_types):
348382
if t != coarse_result_type:
349383
casting[i] = cast_str(restype)
350384

385+
restype = _handle_casting_for_stochastically_rounded_types(datatypes, restype, casting)
386+
351387
return restype, casting
352388

353389

dace/libraries/blas/nodes/dot.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
7070
(desc_x, stride_x), (desc_y, stride_y), desc_res, sz = node.validate(parent_sdfg, parent_state)
7171
dtype = desc_x.dtype.base_type
7272
veclen = desc_x.dtype.veclen
73+
cast = "(float *)" if dtype == dace.float32sr else ""
7374

7475
try:
7576
func, _, _ = blas_helpers.cublas_type_metadata(dtype)
@@ -82,7 +83,8 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
8283
n = n or node.n or sz
8384
if veclen != 1:
8485
n /= veclen
85-
code = f"_result = cblas_{func}({n}, _x, {stride_x}, _y, {stride_y});"
86+
87+
code = f"_result = cblas_{func}({n}, {cast} _x, {stride_x}, {cast} _y, {stride_y});"
8688
# The return type is scalar in cblas_?dot signature
8789
tasklet = dace.sdfg.nodes.Tasklet(node.name,
8890
node.in_connectors, {'_result': dtype},
@@ -204,7 +206,16 @@ def validate(self, sdfg, state):
204206
if desc_x.dtype != desc_y.dtype:
205207
raise TypeError(f"Data types of input operands must be equal: {desc_x.dtype}, {desc_y.dtype}")
206208
if desc_x.dtype.base_type != desc_res.dtype.base_type:
207-
raise TypeError(f"Data types of input and output must be equal: {desc_x.dtype}, {desc_res.dtype}")
209+
arg_types = (desc_x.dtype.base_type, desc_res.dtype.base_type)
210+
if dace.float32 in arg_types and dace.float32sr in arg_types:
211+
"""
212+
When using stochastic rounding, a legitimate (i.e. not a bug) mismatch between the input and output
213+
arguments may arise where one argument is a float32sr and the other is a float32 (round-to-nearest).
214+
The underlying data type is the same so this should not cause the validation to fail.
215+
"""
216+
pass
217+
else:
218+
raise TypeError(f"Data types of input and output must be equal: {desc_x.dtype}, {desc_res.dtype}")
208219

209220
# Squeeze input memlets
210221
squeezed1 = copy.deepcopy(in_memlets[0].subset)

dace/libraries/blas/nodes/gemm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def expansion(node, state, sdfg):
163163
node.validate(sdfg, state)
164164
(_, adesc, _, _, _, _), (_, bdesc, _, _, _, _), _ = _get_matmul_operands(node, state, sdfg)
165165
dtype = adesc.dtype.base_type
166+
166167
func = to_blastype(dtype.type).lower() + 'gemm'
167168
alpha = f'{dtype.ctype}({node.alpha})'
168169
beta = f'{dtype.ctype}({node.beta})'
@@ -178,6 +179,7 @@ def expansion(node, state, sdfg):
178179
check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)
179180

180181
opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, dtype.ctype, func)
182+
opt['cast'] = "(float *)" if dtype == dace.float32sr else ""
181183

182184
# Adaptations for BLAS API
183185
opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
@@ -193,7 +195,7 @@ def expansion(node, state, sdfg):
193195
opt['beta'] = '&__beta'
194196

195197
code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
196-
"{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
198+
"{M}, {N}, {K}, {alpha},{cast} {x}, {lda}, {cast} {y}, {ldb}, {beta}, "
197199
"_c, {ldc});").format_map(opt)
198200

199201
tasklet = dace.sdfg.nodes.Tasklet(

dace/libraries/blas/nodes/gemv.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
232232
name_out="_y")
233233
dtype_a = outer_array_a.dtype.type
234234
dtype = outer_array_x.dtype.base_type
235+
cast = "(float *)" if dtype == dace.float32sr else ""
236+
235237
veclen = outer_array_x.dtype.veclen
236238
alpha = f'{dtype.ctype}({node.alpha})'
237239
beta = f'{dtype.ctype}({node.beta})'
@@ -280,7 +282,7 @@ def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
280282
alpha = '&__alpha'
281283
beta = '&__beta'
282284

283-
code += f"""cblas_{func}({layout}, {trans}, {m}, {n}, {alpha}, _A, {lda},
285+
code += f"""cblas_{func}({layout}, {trans}, {m}, {n}, {alpha}, {cast} _A, {lda},
284286
_x, {strides_x[0]}, {beta}, _y, {strides_y[0]});"""
285287

286288
tasklet = dace.sdfg.nodes.Tasklet(node.name,

dace/libraries/lapack/nodes/potrf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@ class ExpandPotrfOpenBLAS(ExpandTransformation):
3232
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
3333
(desc_x, stride_x, rows_x, cols_x), desc_result = node.validate(parent_sdfg, parent_state)
3434
dtype = desc_x.dtype.base_type
35+
cast = "(float *)" if dtype == dace.float32sr else ""
3536
lapack_dtype = blas_helpers.to_blastype(dtype.type).lower()
3637
if desc_x.dtype.veclen > 1:
3738
raise (NotImplementedError)
3839

3940
n = n or node.n
4041
uplo = "'L'" if node.lower else "'U'"
41-
code = f"_res = LAPACKE_{lapack_dtype}potrf(LAPACK_ROW_MAJOR, {uplo}, {rows_x}, _xin, {stride_x});"
42+
code = f"_res = LAPACKE_{lapack_dtype}potrf(LAPACK_ROW_MAJOR, {uplo}, {rows_x}, {cast} _xin, {stride_x});"
4243
tasklet = dace.sdfg.nodes.Tasklet(node.name,
4344
node.in_connectors,
4445
node.out_connectors,

dace/libraries/linalg/nodes/cholesky.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def _make_sdfg(node, parent_state, parent_sdfg, implementation):
1616

1717
inp_desc, inp_shape, out_desc, out_shape = node.validate(parent_sdfg, parent_state)
1818
dtype = inp_desc.dtype
19+
cast = "(float *)" if dtype == dace.float32sr else ""
1920
storage = inp_desc.storage
2021

2122
sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))
@@ -36,7 +37,7 @@ def _make_sdfg(node, parent_state, parent_sdfg, implementation):
3637
_, me, mx = state.add_mapped_tasklet('_uzero_',
3738
dict(__i="0:%s" % out_shape[0], __j="0:%s" % out_shape[1]),
3839
dict(_inp=Memlet.simple('_b', '__i, __j')),
39-
'_out = (__i < __j) ? 0 : _inp;',
40+
f'_out = (__i < __j) ? {cast}(0) : _inp;',
4041
dict(_out=Memlet.simple('_b', '__i, __j')),
4142
language=dace.dtypes.Language.CPP,
4243
external_edges=True)

dace/runtime/include/dace/dace.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "perf/reporting.h"
2525
#include "comm.h"
2626
#include "serialization.h"
27+
#include "stocastic_rounding.h"
2728

2829
#if defined(__CUDACC__) || defined(__HIPCC__)
2930
#include "cuda/cudacommon.cuh"

dace/runtime/include/dace/reduction.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#define __DACE_REDUCTION_H
44

55
#include <cstdint>
6+
#include <dace/stocastic_rounding.h>
67

78
#include "types.h"
89
#include "vector.h"
@@ -121,6 +122,40 @@ namespace dace {
121122
}
122123
};
123124

125+
template <>
126+
struct wcr_custom<dace::float32sr> {
127+
template <typename WCR>
128+
static DACE_HDFI dace::float32sr reduce_atomic(WCR wcr, dace::float32sr *ptr, const dace::float32sr& value) {
129+
#ifdef DACE_USE_GPU_ATOMICS
130+
// Stochastic rounding version of atomic float reduction
131+
int *iptr = reinterpret_cast<int *>(ptr);
132+
int old = *iptr, assumed;
133+
do {
134+
assumed = old;
135+
float old_val = __int_as_float(assumed);
136+
float new_val = static_cast<float>(wcr(static_cast<dace::float32sr>(old_val), value));
137+
old = atomicCAS(iptr, assumed, __float_as_int(new_val));
138+
} while (assumed != old);
139+
return static_cast<dace::float32sr>(__int_as_float(old));
140+
#else
141+
dace::float32sr old;
142+
#pragma omp critical
143+
{
144+
old = *ptr;
145+
*ptr = wcr(old, value);
146+
}
147+
return old;
148+
#endif
149+
}
150+
151+
template <typename WCR>
152+
static DACE_HDFI dace::float32sr reduce(WCR wcr, dace::float32sr *ptr, const dace::float32sr& value) {
153+
dace::float32sr old = *ptr;
154+
*ptr = wcr(old, value);
155+
return old;
156+
}
157+
};
158+
124159
template <>
125160
struct wcr_custom<double> {
126161
template <typename WCR>
@@ -313,6 +348,31 @@ namespace dace {
313348
DACE_HDFI float operator()(const float &a, const float &b) const { return ::max(a, b); }
314349
};
315350

351+
352+
template <>
353+
struct _wcr_fixed<ReductionType::Min, dace::float32sr> {
354+
355+
static DACE_HDFI dace::float32sr reduce_atomic(dace::float32sr *ptr, const dace::float32sr& value) {
356+
return wcr_custom<dace::float32sr>::reduce_atomic(
357+
_wcr_fixed<ReductionType::Min, dace::float32sr>(), ptr, value);
358+
}
359+
360+
361+
DACE_HDFI dace::float32sr operator()(const dace::float32sr &a, const dace::float32sr &b) const { return ::min(a, b); }
362+
};
363+
364+
template <>
365+
struct _wcr_fixed<ReductionType::Max, dace::float32sr> {
366+
367+
static DACE_HDFI dace::float32sr reduce_atomic(dace::float32sr *ptr, const dace::float32sr& value) {
368+
return wcr_custom<dace::float32sr>::reduce_atomic(
369+
_wcr_fixed<ReductionType::Max, dace::float32sr>(), ptr, value);
370+
}
371+
372+
DACE_HDFI dace::float32sr operator()(const dace::float32sr &a, const dace::float32sr &b) const { return ::max(a, b); }
373+
};
374+
375+
316376
template <>
317377
struct _wcr_fixed<ReductionType::Min, double> {
318378

0 commit comments

Comments
 (0)