diff --git a/cxxheaderparser/lexer.py b/cxxheaderparser/lexer.py index 17caf92..04834c4 100644 --- a/cxxheaderparser/lexer.py +++ b/cxxheaderparser/lexer.py @@ -125,6 +125,8 @@ class PlyLexer: "register", "reinterpret_cast", "requires", + "__restrict__", + "restrict", "return", "short", "signed", diff --git a/cxxheaderparser/parser.py b/cxxheaderparser/parser.py index 54e36c9..6314a4b 100644 --- a/cxxheaderparser/parser.py +++ b/cxxheaderparser/parser.py @@ -2244,7 +2244,9 @@ def _parse_cv_ptr_or_fn( # nonptr_fn is for parsing function types directly in template specialization while True: - tok = self.lex.token_if("*", "const", "volatile", "(") + tok = self.lex.token_if( + "*", "const", "volatile", "__restrict__", "restrict", "(" + ) if not tok: break @@ -2260,6 +2262,10 @@ def _parse_cv_ptr_or_fn( if not isinstance(dtype, (Pointer, Type)): raise self._parse_error(tok) dtype.volatile = True + elif tok.type in ("__restrict__", "restrict"): + if not isinstance(dtype, (Pointer, Reference)): + raise self._parse_error(tok) + dtype.restrict = True elif nonptr_fn: # remove any inner grouping parens while True: @@ -2331,7 +2337,7 @@ def _parse_cv_ptr_or_fn( # peek at the next token and see if it's a paren. If so, it might # be a nasty function pointer - if self.lex.token_peek_if("("): + if self.lex.token_peek_if("(", "__restrict__", "restrict"): dtype = self._parse_cv_ptr_or_fn(dtype, nonptr_fn) return dtype diff --git a/cxxheaderparser/types.py b/cxxheaderparser/types.py index 0eff405..db0dbe5 100644 --- a/cxxheaderparser/types.py +++ b/cxxheaderparser/types.py @@ -336,25 +336,28 @@ class Pointer: const: bool = False volatile: bool = False + restrict: bool = False def format(self) -> str: c = " const" if self.const else "" v = " volatile" if self.volatile else "" + r = " __restrict__" if self.restrict else "" ptr_to = self.ptr_to if isinstance(ptr_to, (Array, FunctionType)): - return ptr_to.format_decl(f"(*{c}{v})") + return ptr_to.format_decl(f"(*{r}{c}{v})") else: - return f"{ptr_to.format()}*{c}{v}" + return f"{ptr_to.format()}*{r}{c}{v}" def format_decl(self, name: str): """Format as a named declaration""" c = " const" if self.const else "" v = " volatile" if self.volatile else "" + r = " __restrict__" if self.restrict else "" ptr_to = self.ptr_to if isinstance(ptr_to, (Array, FunctionType)): - return ptr_to.format_decl(f"(*{c}{v} {name})") + return ptr_to.format_decl(f"(*{r}{c}{v} {name})") else: - return f"{ptr_to.format()}*{c}{v} {name}" + return f"{ptr_to.format()}*{r}{c}{v} {name}" @dataclass @@ -364,13 +367,16 @@ class Reference: """ ref_to: typing.Union[Array, FunctionType, Pointer, Type] + restrict: bool = False def format(self) -> str: ref_to = self.ref_to + if isinstance(ref_to, Array): return ref_to.format_decl("(&)") else: - return f"{ref_to.format()}&" + r = " __restrict__" if self.restrict else "" + return f"{ref_to.format()}&{r}" def format_decl(self, name: str): """Format as a named declaration""" @@ -379,7 +385,8 @@ def format_decl(self, name: str): if isinstance(ref_to, Array): return ref_to.format_decl(f"(& {name})") else: - return f"{ref_to.format()}& {name}" + r = " __restrict__" if self.restrict else "" + return f"{ref_to.format()}&{r} {name}" @dataclass diff --git a/tests/test_fn.py b/tests/test_fn.py index d261032..09f2e34 100644 --- a/tests/test_fn.py +++ b/tests/test_fn.py @@ -139,6 +139,8 @@ def test_fn_pointer_params() -> None: int fn1(int *); int fn2(int *p); int fn3(int(*p)); + int fn4(int* __restrict__ p); + int fn5(int& __restrict__ p); """ data = parse_string(content, cleandoc=True) @@ -198,6 +200,44 @@ def test_fn_pointer_params() -> None: ) ], ), + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="int")]) + ), + name=PQName(segments=[NameSpecifier(name="fn4")]), + parameters=[ + Parameter( + name="p", + type=Pointer( + ptr_to=Type( + typename=PQName( + segments=[FundamentalSpecifier(name="int")] + ) + ), + restrict=True, + ), + ) + ], + ), + Function( + return_type=Type( + typename=PQName(segments=[FundamentalSpecifier(name="int")]) + ), + name=PQName(segments=[NameSpecifier(name="fn5")]), + parameters=[ + Parameter( + name="p", + type=Reference( + ref_to=Type( + typename=PQName( + segments=[FundamentalSpecifier(name="int")] + ) + ), + restrict=True, + ), + ) + ], + ), ] ) )