Skip to content

Commit 3abac6d

Browse files
authored
Merge pull request #132 from KangarooKoala/better-requires-clauses
Improve requires clause parsing
2 parents 9116a16 + dc63715 commit 3abac6d

File tree

2 files changed

+89
-33
lines changed

2 files changed

+89
-33
lines changed

cxxheaderparser/parser.py

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -803,13 +803,6 @@ def _parse_concept(
803803
),
804804
)
805805

806-
# fmt: off
807-
_expr_operators = {
808-
"<", ">", "|", "%", "^", "!", "*", "-", "+", "&", "=",
809-
"&&", "||", "<<"
810-
}
811-
# fmt: on
812-
813806
def _parse_requires(
814807
self,
815808
tok: LexToken,
@@ -818,38 +811,48 @@ def _parse_requires(
818811

819812
rawtoks: typing.List[LexToken] = []
820813

821-
# The easier case -- requires requires
822-
if tok.type == "requires":
823-
rawtoks.append(tok)
824-
for tt in ("(", "{"):
825-
tok = self._next_token_must_be(tt)
814+
# The expression in a requires clause must be one of the following:
815+
# 1) A primary expression
816+
# 2) A sequence of (1) joined with &&
817+
# 3) A sequence of (2) joined with ||
818+
#
819+
# In terms of validity, this is equivalent to a sequence of primary expressions
820+
# joined with && and/or ||.
821+
#
822+
# In general, a primary expression is one of the following:
823+
# 1) this
824+
# 2) a literal
825+
# 3) an identifier expression
826+
# 4) a lambda expression
827+
# 5) a fold expression
828+
# 6) a requires expression
829+
# 7) any parenthesized expression
830+
#
831+
# For simplicity, we only consider the following primary expressions:
832+
# 1) parenthesized expressions (which includes fold expressions)
833+
# 2) requires expressions
834+
# 3) identifer expressions (possibly qualified, possibly templated)
835+
while True:
836+
if tok.type == "(":
826837
rawtoks.extend(self._consume_balanced_tokens(tok))
827-
# .. and that's it?
828-
829-
# this is either a parenthesized expression or a primary clause
830-
elif tok.type == "(":
831-
rawtoks.extend(self._consume_balanced_tokens(tok))
832-
else:
833-
while True:
834-
if tok.type == "(":
838+
tok = self.lex.token()
839+
elif tok.type == "requires":
840+
rawtoks.append(tok)
841+
for tt in ("(", "{"):
842+
tok = self._next_token_must_be(tt)
835843
rawtoks.extend(self._consume_balanced_tokens(tok))
836-
else:
837-
tok = self._parse_requires_segment(tok, rawtoks)
844+
tok = self.lex.token()
845+
else:
846+
tok = self._parse_requires_segment(tok, rawtoks)
838847

839-
# If this is not an operator of some kind, we don't know how
840-
# to proceed so let the next parser figure it out
841-
if tok.value not in self._expr_operators:
842-
break
848+
if tok.value not in ("&&", "||"):
849+
break
843850

844-
rawtoks.append(tok)
851+
rawtoks.append(tok)
845852

846-
# check once more for compound operator?
847-
tok = self.lex.token()
848-
if tok.value in self._expr_operators:
849-
rawtoks.append(tok)
850-
tok = self.lex.token()
853+
tok = self.lex.token()
851854

852-
self.lex.return_token(tok)
855+
self.lex.return_token(tok)
853856

854857
return self._create_value(rawtoks)
855858

tests/test_concepts.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,59 @@ def test_requires_compound() -> None:
659659
)
660660

661661

662+
def test_requires_compound_parenthesized() -> None:
663+
content = """
664+
template <int X>
665+
requires (X == 0) || (X == 1)
666+
int Fibonacci() { return X; }
667+
"""
668+
data = parse_string(content, cleandoc=True)
669+
670+
assert data == ParsedData(
671+
namespace=NamespaceScope(
672+
functions=[
673+
Function(
674+
return_type=Type(
675+
typename=PQName(segments=[FundamentalSpecifier(name="int")])
676+
),
677+
name=PQName(segments=[NameSpecifier(name="Fibonacci")]),
678+
parameters=[],
679+
has_body=True,
680+
template=TemplateDecl(
681+
params=[
682+
TemplateNonTypeParam(
683+
type=Type(
684+
typename=PQName(
685+
segments=[FundamentalSpecifier(name="int")]
686+
)
687+
),
688+
name="X",
689+
)
690+
],
691+
raw_requires_pre=Value(
692+
tokens=[
693+
Token(value="("),
694+
Token(value="X"),
695+
Token(value="="),
696+
Token(value="="),
697+
Token(value="0"),
698+
Token(value=")"),
699+
Token(value="||"),
700+
Token(value="("),
701+
Token(value="X"),
702+
Token(value="="),
703+
Token(value="="),
704+
Token(value="1"),
705+
Token(value=")"),
706+
]
707+
),
708+
),
709+
)
710+
]
711+
)
712+
)
713+
714+
662715
def test_requires_ad_hoc() -> None:
663716
content = """
664717
template<typename T>

0 commit comments

Comments
 (0)