Skip to content

Commit 88f988a

Browse files
committed
Fix a bug in dependencies. Implement support for relational operations.
1 parent 345c9d9 commit 88f988a

File tree

3 files changed

+157
-29
lines changed

3 files changed

+157
-29
lines changed

numpy/f2py/crackfortran.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2701,11 +2701,6 @@ def analyzevars(block):
27012701
'required' if is_required else 'optional')
27022702
if v_attr:
27032703
vars[v]['attrspec'] = v_attr
2704-
else:
2705-
# n is output or hidden argument, hence it
2706-
# will depend on all variables in d
2707-
n_deps.extend(coeffs_and_deps)
2708-
27092704
if coeffs_and_deps is not None:
27102705
# extend v dependencies with ones specified in attrspec
27112706
for v, (solver, deps) in coeffs_and_deps.items():
@@ -2716,6 +2711,8 @@ def analyzevars(block):
27162711
v_deps.extend(aa[7:-1].split(','))
27172712
if v_deps:
27182713
vars[v]['depend'] = list(set(v_deps))
2714+
if n not in v_deps:
2715+
n_deps.append(v)
27192716
elif isstring(vars[n]):
27202717
if 'charselector' in vars[n]:
27212718
if '*' in vars[n]['charselector']:

numpy/f2py/symbolic.py

Lines changed: 118 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#
1616
# TODO: support logical constants (Op.BOOLEAN)
1717
# TODO: support logical operators (.AND., ...)
18-
# TODO: support relational operators (<, >, ..., .LT., ...)
1918
# TODO: support defined operators (.MYOP., ...)
2019
#
2120
__all__ = ['Expr']
@@ -50,12 +49,43 @@ class Op(Enum):
5049
APPLY = 200
5150
INDEXING = 210
5251
CONCAT = 220
52+
RELATIONAL = 300
5353
TERMS = 1000
5454
FACTORS = 2000
5555
REF = 3000
5656
DEREF = 3001
5757

5858

59+
class RelOp(Enum):
60+
"""
61+
Used in Op.RELATIONAL expression to specify the function part.
62+
"""
63+
EQ = 1
64+
NE = 2
65+
LT = 3
66+
LE = 4
67+
GT = 5
68+
GE = 6
69+
70+
@classmethod
71+
def fromstring(cls, s, language=Language.C):
72+
if language is Language.Fortran:
73+
return {'.eq.': RelOp.EQ, '.ne.': RelOp.NE,
74+
'.lt.': RelOp.LT, '.le.': RelOp.LE,
75+
'.gt.': RelOp.GT, '.ge.': RelOp.GE}[s.lower()]
76+
return {'==': RelOp.EQ, '!=': RelOp.NE, '<': RelOp.LT,
77+
'<=': RelOp.LE, '>': RelOp.GT, '>=': RelOp.GE}[s]
78+
79+
def tostring(self, language=Language.C):
80+
if language is Language.Fortran:
81+
return {RelOp.EQ: '.eq.', RelOp.NE: '.ne.',
82+
RelOp.LT: '.lt.', RelOp.LE: '.le.',
83+
RelOp.GT: '.gt.', RelOp.GE: '.ge.'}[self]
84+
return {RelOp.EQ: '==', RelOp.NE: '!=',
85+
RelOp.LT: '<', RelOp.LE: '<=',
86+
RelOp.GT: '>', RelOp.GE: '>='}[self]
87+
88+
5989
class ArithOp(Enum):
6090
"""
6191
Used in Op.APPLY expression to specify the function part.
@@ -77,12 +107,19 @@ class Precedence(Enum):
77107
"""
78108
Used as Expr.tostring precedence argument.
79109
"""
80-
NONE = 0
81-
TUPLE = 1
82-
SUM = 2
110+
ATOM = 0
111+
POWER = 1
112+
UNARY = 2
83113
PRODUCT = 3
84-
POWER = 4
85-
ATOM = 5
114+
SUM = 4
115+
LT = 6
116+
EQ = 7
117+
LAND = 11
118+
LOR = 12
119+
TERNARY = 13
120+
ASSIGN = 14
121+
TUPLE = 15
122+
NONE = 100
86123

87124

88125
integer_types = (int,)
@@ -178,6 +215,9 @@ def __init__(self, op, data):
178215
elif op in (Op.REF, Op.DEREF):
179216
# data is Expr instance
180217
assert isinstance(data, Expr)
218+
elif op is Op.RELATIONAL:
219+
# data is (<relop>, <left>, <right>)
220+
assert isinstance(data, tuple) and len(data) == 3
181221
else:
182222
raise NotImplementedError(
183223
f'unknown op or missing sanity check: {op}')
@@ -341,19 +381,32 @@ def tostring(self, parent_precedence=Precedence.NONE,
341381
language=language)
342382
for a in self.data]
343383
if language is Language.C:
344-
return f'({cond} ? {expr1} : {expr2})'
345-
if language is Language.Python:
346-
return f'({expr1} if {cond} else {expr2})'
347-
if language is Language.Fortran:
348-
return f'merge({expr1}, {expr2}, {cond})'
349-
raise NotImplementedError(f'tostring for {self.op} and {language}')
384+
r = f'({cond} ? {expr1} : {expr2})'
385+
elif language is Language.Python:
386+
r = f'({expr1} if {cond} else {expr2})'
387+
elif language is Language.Fortran:
388+
r = f'merge({expr1}, {expr2}, {cond})'
389+
else:
390+
raise NotImplementedError(
391+
f'tostring for {self.op} and {language}')
392+
precedence = Precedence.ATOM
350393
elif self.op is Op.REF:
351-
return '&' + self.data.tostring(language=language)
394+
r = '&' + self.data.tostring(Precedence.UNARY, language=language)
395+
precedence = Precedence.UNARY
352396
elif self.op is Op.DEREF:
353-
return '*' + self.data.tostring(language=language)
397+
r = '*' + self.data.tostring(Precedence.UNARY, language=language)
398+
precedence = Precedence.UNARY
399+
elif self.op is Op.RELATIONAL:
400+
rop, left, right = self.data
401+
precedence = (Precedence.EQ if rop in (RelOp.EQ, RelOp.NE)
402+
else Precedence.LT)
403+
left = left.tostring(precedence, language=language)
404+
right = right.tostring(precedence, language=language)
405+
rop = rop.tostring(language=language)
406+
r = f'{left} {rop} {right}'
354407
else:
355408
raise NotImplementedError(f'tostring for op {self.op}')
356-
if parent_precedence.value > precedence.value:
409+
if parent_precedence.value < precedence.value:
357410
# If parent precedence is higher than operand precedence,
358411
# operand will be enclosed in parenthesis.
359412
return '(' + r + ')'
@@ -590,7 +643,11 @@ def substitute(self, symbols_map):
590643
return normalize(Expr(self.op, operands))
591644
if self.op in (Op.REF, Op.DEREF):
592645
return normalize(Expr(self.op, self.data.substitute(symbols_map)))
593-
646+
if self.op is Op.RELATIONAL:
647+
rop, left, right = self.data
648+
left = left.substitute(symbols_map)
649+
right = right.substitute(symbols_map)
650+
return normalize(Expr(self.op, (rop, left, right)))
594651
raise NotImplementedError(f'substitute method for {self.op}: {self!r}')
595652

596653
def traverse(self, visit, *args, **kwargs):
@@ -642,6 +699,11 @@ def traverse(self, visit, *args, **kwargs):
642699
elif self.op in (Op.REF, Op.DEREF):
643700
return normalize(Expr(self.op,
644701
self.data.traverse(visit, *args, **kwargs)))
702+
elif self.op is Op.RELATIONAL:
703+
rop, left, right = self.data
704+
left = left.traverse(visit, *args, **kwargs)
705+
right = right.traverse(visit, *args, **kwargs)
706+
return normalize(Expr(self.op, (rop, left, right)))
645707
raise NotImplementedError(f'traverse method for {self.op}')
646708

647709
def contains(self, other):
@@ -963,6 +1025,30 @@ def as_deref(expr):
9631025
return Expr(Op.DEREF, expr)
9641026

9651027

1028+
def as_eq(left, right):
1029+
return Expr(Op.RELATIONAL, (RelOp.EQ, left, right))
1030+
1031+
1032+
def as_ne(left, right):
1033+
return Expr(Op.RELATIONAL, (RelOp.NE, left, right))
1034+
1035+
1036+
def as_lt(left, right):
1037+
return Expr(Op.RELATIONAL, (RelOp.LT, left, right))
1038+
1039+
1040+
def as_le(left, right):
1041+
return Expr(Op.RELATIONAL, (RelOp.LE, left, right))
1042+
1043+
1044+
def as_gt(left, right):
1045+
return Expr(Op.RELATIONAL, (RelOp.GT, left, right))
1046+
1047+
1048+
def as_ge(left, right):
1049+
return Expr(Op.RELATIONAL, (RelOp.GE, left, right))
1050+
1051+
9661052
def as_terms(obj):
9671053
"""Return expression as TERMS expression.
9681054
"""
@@ -1257,23 +1343,32 @@ def restore(r):
12571343
return as_complex(*self.process(operands))
12581344
raise NotImplementedError(
12591345
f'parsing comma-separated list (context={context}): {r}')
1260-
return tuple(self.process(restore(r.split(',')), context))
12611346

12621347
# ternary operation
12631348
m = re.match(r'\A([^?]+)[?]([^:]+)[:](.+)\Z', r)
12641349
if m:
12651350
assert context == 'expr', context
12661351
oper, expr1, expr2 = restore(m.groups())
1267-
if 0:
1268-
# TODO: enable this when support for boolean
1269-
# expressions is fully implemented
1270-
oper = self.process(oper)
1271-
else:
1272-
oper = as_symbol(self.finalize_string(oper))
1352+
oper = self.process(oper)
12731353
expr1 = self.process(expr1)
12741354
expr2 = self.process(expr2)
12751355
return as_ternary(oper, expr1, expr2)
12761356

1357+
# relational expression
1358+
if self.language is Language.Fortran:
1359+
m = re.match(
1360+
r'\A(.+)\s*[.](eq|ne|lt|le|gt|ge)[.]\s*(.+)\Z', r, re.I)
1361+
else:
1362+
m = re.match(
1363+
r'\A(.+)\s*([=][=]|[!][=]|[<][=]|[<]|[>][=]|[>])\s*(.+)\Z', r)
1364+
if m:
1365+
left, rop, right = m.groups()
1366+
if self.language is Language.Fortran:
1367+
rop = '.' + rop + '.'
1368+
left, right = self.process(restore((left, right)))
1369+
rop = RelOp.fromstring(rop, language=self.language)
1370+
return Expr(Op.RELATIONAL, (rop, left, right))
1371+
12771372
# keyword argument
12781373
m = re.match(r'\A(\w[\w\d_]*)\s*[=](.*)\Z', r)
12791374
if m:

numpy/f2py/tests/test_symbolic.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
as_terms, as_factors, eliminate_quotes, insert_quotes,
66
fromstring, as_expr, as_apply,
77
as_numer_denom, as_ternary, as_ref, as_deref,
8-
normalize
8+
normalize, as_eq, as_ne, as_lt, as_gt, as_le, as_ge
99
)
1010
from . import util
1111

@@ -100,6 +100,13 @@ def test_sanity(self):
100100
assert t != u
101101
assert hash(t) is not None
102102

103+
e = as_eq(x, y)
104+
f = as_lt(x, y)
105+
assert e.op == Op.RELATIONAL
106+
assert e == e
107+
assert e != f
108+
assert hash(e) is not None
109+
103110
def test_tostring_fortran(self):
104111
x = as_symbol('x')
105112
y = as_symbol('y')
@@ -142,6 +149,12 @@ def test_tostring_fortran(self):
142149
assert str(Expr(Op.INDEXING, ('f', x))) == 'f[x]'
143150

144151
assert str(as_ternary(x, y, z)) == 'merge(y, z, x)'
152+
assert str(as_eq(x, y)) == 'x .eq. y'
153+
assert str(as_ne(x, y)) == 'x .ne. y'
154+
assert str(as_lt(x, y)) == 'x .lt. y'
155+
assert str(as_le(x, y)) == 'x .le. y'
156+
assert str(as_gt(x, y)) == 'x .gt. y'
157+
assert str(as_ge(x, y)) == 'x .ge. y'
145158

146159
def test_tostring_c(self):
147160
language = Language.C
@@ -166,6 +179,12 @@ def test_tostring_c(self):
166179
language=language) == '123 + x + (x - y) / (x + y)'
167180

168181
assert as_ternary(x, y, z).tostring(language=language) == '(x ? y : z)'
182+
assert as_eq(x, y).tostring(language=language) == 'x == y'
183+
assert as_ne(x, y).tostring(language=language) == 'x != y'
184+
assert as_lt(x, y).tostring(language=language) == 'x < y'
185+
assert as_le(x, y).tostring(language=language) == 'x <= y'
186+
assert as_gt(x, y).tostring(language=language) == 'x > y'
187+
assert as_ge(x, y).tostring(language=language) == 'x >= y'
169188

170189
def test_operations(self):
171190
x = as_symbol('x')
@@ -240,6 +259,8 @@ def test_substitute(self):
240259

241260
assert as_ternary(x, y, z).substitute(
242261
{x: y + z}) == as_ternary(y + z, y, z)
262+
assert as_eq(x, y).substitute(
263+
{x: y + z}) == as_eq(y + z, y)
243264

244265
def test_fromstring(self):
245266

@@ -319,6 +340,20 @@ def test_fromstring(self):
319340
assert fromstring('*x * *y') == as_deref(x) * as_deref(y)
320341
assert fromstring('*x**y') == as_deref(x) * as_deref(y)
321342

343+
assert fromstring('x == y') == as_eq(x, y)
344+
assert fromstring('x != y') == as_ne(x, y)
345+
assert fromstring('x < y') == as_lt(x, y)
346+
assert fromstring('x > y') == as_gt(x, y)
347+
assert fromstring('x <= y') == as_le(x, y)
348+
assert fromstring('x >= y') == as_ge(x, y)
349+
350+
assert fromstring('x .eq. y', language=Language.Fortran) == as_eq(x, y)
351+
assert fromstring('x .ne. y', language=Language.Fortran) == as_ne(x, y)
352+
assert fromstring('x .lt. y', language=Language.Fortran) == as_lt(x, y)
353+
assert fromstring('x .gt. y', language=Language.Fortran) == as_gt(x, y)
354+
assert fromstring('x .le. y', language=Language.Fortran) == as_le(x, y)
355+
assert fromstring('x .ge. y', language=Language.Fortran) == as_ge(x, y)
356+
322357
def test_traverse(self):
323358
x = as_symbol('x')
324359
y = as_symbol('y')
@@ -340,6 +375,7 @@ def replace_visit(s, r=z):
340375
assert (x + y + z).traverse(replace_visit) == (2 * z + y)
341376
assert (x + f(y, x - z)).traverse(
342377
replace_visit) == (z + f(y, as_number(0)))
378+
assert as_eq(x, y).traverse(replace_visit) == as_eq(z, y)
343379

344380
# Use traverse to collect symbols, method 1
345381
function_symbols = set()

0 commit comments

Comments
 (0)