@@ -113,6 +113,36 @@ def _is_op_boolean(op: str):
113113 return False
114114
115115
116+ def _handle_casting_for_stochastically_rounded_types (input_datatypes , restype , cast_types ):
117+ float_to_sr = {
118+ dace .float32 : dace .float32sr ,
119+ }
120+
121+ for i , dtype in enumerate (input_datatypes ):
122+ if hasattr (dtype , "stochastically_rounded" ):
123+ if cast_types [i ] and dtype .type == eval (cast_types [i ]).type :
124+ cast_types [i ] = None
125+
126+ # check if stoc rounded inputs
127+ stochastically_rounded = True
128+ for i , dtype in enumerate (input_datatypes ):
129+ if not hasattr (dtype , "stochastically_rounded" ):
130+ stochastically_rounded = False
131+ break
132+
133+ if stochastically_rounded :
134+ # make the result SR
135+ if restype in float_to_sr :
136+ restype = float_to_sr [restype ]
137+
138+ # cast the intermediate types
139+ for i , dtype in enumerate (cast_types ):
140+ if dtype in float_to_sr :
141+ cast_types [i ] = float_to_sr [dtype ]
142+
143+ return restype
144+
145+
116146def result_type (arguments : Sequence [Union [str , Number , symbolic .symbol , sp .Basic ]],
117147 operator : str = None ) -> Tuple [Union [List [dtypes .typeclass ], dtypes .typeclass , str ], ...]:
118148
@@ -144,12 +174,16 @@ def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic
144174 raise TypeError ("Type {t} of argument {a} is not supported" .format (t = type (arg ), a = arg ))
145175
146176 complex_types = {dtypes .complex64 , dtypes .complex128 , np .complex64 , np .complex128 }
147- float_types = {dtypes .float16 , dtypes .float32 , dtypes .float64 , np .float16 , np .float32 , np .float64 }
177+ float_types = {dace .float16 , dace .float32 , dace . float32sr , dace .float64 , np .float16 , np .float32 , np .float64 }
148178 signed_types = {dtypes .int8 , dtypes .int16 , dtypes .int32 , dtypes .int64 , np .int8 , np .int16 , np .int32 , np .int64 }
149179 # unsigned_types = {np.uint8, np.uint16, np.uint32, np.uint64}
150180
151181 coarse_types = []
152- for dtype in datatypes :
182+ for dt in datatypes :
183+ dtype = dt
184+ if hasattr (dt , "srtype" ): # unwrap stochastically rounded vars
185+ dtype = dt .srtype
186+
153187 if dtype in complex_types :
154188 coarse_types .append (3 ) # complex
155189 elif dtype in float_types :
@@ -336,18 +370,20 @@ def result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basic
336370 else : # Operators with 3 or more arguments
337371 restype = np_result_type (dtypes_for_result )
338372 coarse_result_type = None
339- if result_type in complex_types :
373+ if restype in complex_types :
340374 coarse_result_type = 3 # complex
341- elif result_type in float_types :
375+ elif restype in float_types :
342376 coarse_result_type = 2 # float
343- elif result_type in signed_types :
377+ elif restype in signed_types :
344378 coarse_result_type = 1 # signed integer, bool
345379 else :
346380 coarse_result_type = 0 # unsigned integer
347381 for i , t in enumerate (coarse_types ):
348382 if t != coarse_result_type :
349383 casting [i ] = cast_str (restype )
350384
385+ restype = _handle_casting_for_stochastically_rounded_types (datatypes , restype , casting )
386+
351387 return restype , casting
352388
353389
0 commit comments