Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions dace/frontend/fortran/ast_desugaring.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,37 @@ def _eval_selected_real_kind(p: int, r: int) -> int:
return 2


def _selected_real_kind(args: Tuple[Union[Int_Literal_Constant, Actual_Arg_Spec]],
alias_map: SPEC_TABLE) -> int:
p = np.int32(0)
r = np.int32(0)
radix_err = NotImplementedError("Cannot handle RADIX argument for SELECTED_REAL_KIND.")
if len(args) > 2:
raise radix_err

nextarg = 'p'
for arg in args:
argname = nextarg
argval = arg
if isinstance(arg, Actual_Arg_Spec):
argname, argval = arg.children
assert(isinstance(argname, Name))
argname = argname.string.lower()

if argname == 'p':
p = _const_eval_basic_type(argval, alias_map)
nextarg = 'r'
elif argname == 'r':
r = _const_eval_basic_type(argval, alias_map)
nextarg = 'p'
else:
assert(argname == 'radix')
raise radix_err

assert isinstance(p, np.int32) and isinstance(r, np.int32)
return np.int32(_eval_selected_real_kind(p, r))


def _const_eval_int(expr: Base, alias_map: SPEC_TABLE) -> Optional[int]:
if isinstance(expr, Name):
scope_spec = find_scope_spec(expr)
Expand All @@ -479,11 +510,7 @@ def _const_eval_int(expr: Base, alias_map: SPEC_TABLE) -> Optional[int]:
if args:
args = args.children
if intr.string == 'SELECTED_REAL_KIND':
assert len(args) == 2
p, r = args
p, r = _const_eval_int(p, alias_map), _const_eval_int(r, alias_map)
assert p is not None and r is not None
return _eval_selected_real_kind(p, r)
return _selected_real_kind(args, alias_map)
elif intr.string == 'SELECTED_INT_KIND':
assert len(args) == 1
p, = args
Expand Down Expand Up @@ -691,10 +718,7 @@ def _const_eval_basic_type(expr: Base, alias_map: SPEC_TABLE) -> Optional[NUMPY_
if all(isinstance(a, (np.float32, np.float64)) for a in avals):
return INTR_FNS[intr.string](*avals)
elif intr.string == 'SELECTED_REAL_KIND':
p, r = args
p, r = _const_eval_basic_type(p, alias_map), _const_eval_basic_type(r, alias_map)
assert isinstance(p, np.int32) and isinstance(r, np.int32)
return np.int32(_eval_selected_real_kind(p, r))
return _selected_real_kind(args, alias_map)
elif intr.string == 'SELECTED_INT_KIND':
p, = args
p = _const_eval_basic_type(p, alias_map)
Expand Down
64 changes: 64 additions & 0 deletions tests/fortran/ast_desugaring_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2541,6 +2541,70 @@ def test_exploit_locally_constant_pointers():
SourceCodeBuilder().add_file(got).check_with_gfortran()


def test_exploit_locally_constant_selected_real_kind():
sources, main = SourceCodeBuilder().add_file("""
subroutine main
implicit none
integer :: dummy
integer :: sp, dp, ep, qp, xp, yp, zp
! integer :: ap, bp

sp = selected_real_kind(6, 37)
dummy = sp
dp = selected_real_kind(12, 307)
dummy = dp
ep = selected_real_kind(12)
dummy = ep
! TODO: Technically we should return an error code if the precision
! is not available:
! https://gcc.gnu.org/onlinedocs/gfortran/SELECTED_005fREAL_005fKIND.html
qp = selected_real_kind(30)
dummy = qp

! Test the optional argument mechanism:
xp = selected_real_kind(r=37)
dummy = xp
yp = selected_real_kind(p=1)
dummy = yp
zp = selected_real_kind(p=12,r=1)
dummy = zp

! TODO: We do not support any of the following:
! ap = selected_real_kind(6, 37, 3)
! dummy = ap
! bp = selected_real_kind(radix=5)
! dummy = bp
end subroutine main
""").check_with_gfortran().get()
ast = parse_and_improve(sources)
ast = exploit_locally_constant_variables(ast)

got = ast.tofortran()
want = """
SUBROUTINE main
IMPLICIT NONE
INTEGER :: dummy
INTEGER :: sp, dp, ep, qp, xp, yp, zp
sp = SELECTED_REAL_KIND(6, 37)
dummy = 4
dp = SELECTED_REAL_KIND(12, 307)
dummy = 8
ep = SELECTED_REAL_KIND(12)
dummy = 8
qp = SELECTED_REAL_KIND(30)
dummy = 8
xp = SELECTED_REAL_KIND(r = 37)
dummy = 4
yp = SELECTED_REAL_KIND(p = 1)
dummy = 2
zp = SELECTED_REAL_KIND(p = 12, r = 1)
dummy = 8
END SUBROUTINE main
""".strip()
assert got == want
SourceCodeBuilder().add_file(got).check_with_gfortran()


def test_consolidate_global_data():
sources, main = SourceCodeBuilder().add_file("""
module lib
Expand Down