diff --git a/dace/frontend/fortran/ast_desugaring.py b/dace/frontend/fortran/ast_desugaring.py index 661872bd2b..41350ba7dc 100644 --- a/dace/frontend/fortran/ast_desugaring.py +++ b/dace/frontend/fortran/ast_desugaring.py @@ -463,6 +463,37 @@ def _eval_selected_real_kind(p: int, r: int) -> int: return 2 +def _selected_real_kind(args: Tuple[Union[Int_Literal_Constant, Actual_Arg_Spec]], + alias_map: SPEC_TABLE) -> int: + p = np.int32(0) + r = np.int32(0) + radix_err = NotImplementedError("Cannot handle RADIX argument for SELECTED_REAL_KIND.") + if len(args) > 2: + raise radix_err + + nextarg = 'p' + for arg in args: + argname = nextarg + argval = arg + if isinstance(arg, Actual_Arg_Spec): + argname, argval = arg.children + assert(isinstance(argname, Name)) + argname = argname.string.lower() + + if argname == 'p': + p = _const_eval_basic_type(argval, alias_map) + nextarg = 'r' + elif argname == 'r': + r = _const_eval_basic_type(argval, alias_map) + nextarg = 'p' + else: + assert(argname == 'radix') + raise radix_err + + assert isinstance(p, np.int32) and isinstance(r, np.int32) + return np.int32(_eval_selected_real_kind(p, r)) + + def _const_eval_int(expr: Base, alias_map: SPEC_TABLE) -> Optional[int]: if isinstance(expr, Name): scope_spec = find_scope_spec(expr) @@ -479,11 +510,7 @@ def _const_eval_int(expr: Base, alias_map: SPEC_TABLE) -> Optional[int]: if args: args = args.children if intr.string == 'SELECTED_REAL_KIND': - assert len(args) == 2 - p, r = args - p, r = _const_eval_int(p, alias_map), _const_eval_int(r, alias_map) - assert p is not None and r is not None - return _eval_selected_real_kind(p, r) + return _selected_real_kind(args, alias_map) elif intr.string == 'SELECTED_INT_KIND': assert len(args) == 1 p, = args @@ -691,10 +718,7 @@ def _const_eval_basic_type(expr: Base, alias_map: SPEC_TABLE) -> Optional[NUMPY_ if all(isinstance(a, (np.float32, np.float64)) for a in avals): return INTR_FNS[intr.string](*avals) elif intr.string == 'SELECTED_REAL_KIND': - p, r = args - p, r = _const_eval_basic_type(p, alias_map), _const_eval_basic_type(r, alias_map) - assert isinstance(p, np.int32) and isinstance(r, np.int32) - return np.int32(_eval_selected_real_kind(p, r)) + return _selected_real_kind(args, alias_map) elif intr.string == 'SELECTED_INT_KIND': p, = args p = _const_eval_basic_type(p, alias_map) diff --git a/tests/fortran/ast_desugaring_test.py b/tests/fortran/ast_desugaring_test.py index b9dff85461..310c6acfa5 100644 --- a/tests/fortran/ast_desugaring_test.py +++ b/tests/fortran/ast_desugaring_test.py @@ -2541,6 +2541,70 @@ def test_exploit_locally_constant_pointers(): SourceCodeBuilder().add_file(got).check_with_gfortran() +def test_exploit_locally_constant_selected_real_kind(): + sources, main = SourceCodeBuilder().add_file(""" +subroutine main + implicit none + integer :: dummy + integer :: sp, dp, ep, qp, xp, yp, zp + ! integer :: ap, bp + + sp = selected_real_kind(6, 37) + dummy = sp + dp = selected_real_kind(12, 307) + dummy = dp + ep = selected_real_kind(12) + dummy = ep + ! TODO: Technically we should return an error code if the precision + ! is not available: + ! https://gcc.gnu.org/onlinedocs/gfortran/SELECTED_005fREAL_005fKIND.html + qp = selected_real_kind(30) + dummy = qp + + ! Test the optional argument mechanism: + xp = selected_real_kind(r=37) + dummy = xp + yp = selected_real_kind(p=1) + dummy = yp + zp = selected_real_kind(p=12,r=1) + dummy = zp + + ! TODO: We do not support any of the following: + ! ap = selected_real_kind(6, 37, 3) + ! dummy = ap + ! bp = selected_real_kind(radix=5) + ! dummy = bp +end subroutine main +""").check_with_gfortran().get() + ast = parse_and_improve(sources) + ast = exploit_locally_constant_variables(ast) + + got = ast.tofortran() + want = """ +SUBROUTINE main + IMPLICIT NONE + INTEGER :: dummy + INTEGER :: sp, dp, ep, qp, xp, yp, zp + sp = SELECTED_REAL_KIND(6, 37) + dummy = 4 + dp = SELECTED_REAL_KIND(12, 307) + dummy = 8 + ep = SELECTED_REAL_KIND(12) + dummy = 8 + qp = SELECTED_REAL_KIND(30) + dummy = 8 + xp = SELECTED_REAL_KIND(r = 37) + dummy = 4 + yp = SELECTED_REAL_KIND(p = 1) + dummy = 2 + zp = SELECTED_REAL_KIND(p = 12, r = 1) + dummy = 8 +END SUBROUTINE main +""".strip() + assert got == want + SourceCodeBuilder().add_file(got).check_with_gfortran() + + def test_consolidate_global_data(): sources, main = SourceCodeBuilder().add_file(""" module lib