Skip to content

Commit 222bc03

Browse files
committed
[Macros] Prevent recursive expansion of macros.
1 parent 6bb9cb8 commit 222bc03

File tree

4 files changed

+54
-1
lines changed

4 files changed

+54
-1
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6799,6 +6799,8 @@ ERROR(macro_expansion_missing_arguments,none,
67996799
"expansion of macro %0 requires arguments", (DeclName))
68006800
ERROR(macro_unsupported,none,
68016801
"macros are not supported in this compiler", ())
6802+
ERROR(macro_recursive,none,
6803+
"recursive expansion of macro %0", (DeclName))
68026804

68036805
//------------------------------------------------------------------------------
68046806
// MARK: Move Only Errors

lib/Sema/TypeCheckMacros.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,34 @@ MacroDefinition MacroDefinitionRequest::evaluate(
121121
ctx, macro->externalModuleName, macro->externalMacroTypeName);
122122
}
123123

124+
/// Determine whether the given source file is from an expansion of the given
125+
/// macro.
126+
static bool isFromExpansionOfMacro(SourceFile *sourceFile, MacroDecl *macro) {
127+
while (sourceFile) {
128+
auto expansion = sourceFile->getMacroExpansion();
129+
if (!expansion)
130+
return false;
131+
132+
if (auto expansionExpr = dyn_cast_or_null<MacroExpansionExpr>(
133+
expansion.dyn_cast<Expr *>())) {
134+
if (expansionExpr->getMacroRef().getDecl() == macro)
135+
return true;
136+
} else if (auto expansionDecl = dyn_cast_or_null<MacroExpansionDecl>(
137+
expansion.dyn_cast<Decl *>())) {
138+
// FIXME: Update once MacroExpansionDecl has a proper macro reference
139+
// in it.
140+
if (expansionDecl->getMacro().getFullName() == macro->getName())
141+
return true;
142+
} else {
143+
llvm_unreachable("Unknown macro expansion node kind");
144+
}
145+
146+
sourceFile = sourceFile->getEnclosingSourceFile();
147+
}
148+
149+
return false;
150+
}
151+
124152
Expr *swift::expandMacroExpr(
125153
DeclContext *dc, Expr *expr, ConcreteDeclRef macroRef, Type expandedType
126154
) {
@@ -137,6 +165,11 @@ Expr *swift::expandMacroExpr(
137165

138166
MacroDecl *macro = cast<MacroDecl>(macroRef.getDecl());
139167

168+
if (isFromExpansionOfMacro(sourceFile, macro)) {
169+
ctx.Diags.diagnose(expr->getLoc(), diag::macro_recursive, macro->getName());
170+
return nullptr;
171+
}
172+
140173
auto macroDef = evaluateOrDefault(
141174
ctx.evaluator, MacroDefinitionRequest{macro},
142175
MacroDefinition::forUndefined());

test/Macros/Inputs/syntax_macro_definitions.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,16 @@ public struct AddBlocker: ExpressionMacro {
154154
return result.as(MacroExpansionExprSyntax.self)!.argumentList.first!.expression
155155
}
156156
}
157+
158+
public struct RecursiveMacro: ExpressionMacro {
159+
public static func expansion(
160+
of macro: MacroExpansionExprSyntax, in context: inout MacroExpansionContext
161+
) -> ExprSyntax {
162+
guard let argument = macro.argumentList.first?.expression,
163+
argument.description == "false" else {
164+
return ExprSyntax(macro)
165+
}
166+
167+
return "()"
168+
}
169+
}

test/Macros/macro_expand.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
@expression macro customFileID: String = MacroDefinition.FileIDMacro
1313
@expression macro stringify<T>(_ value: T) -> (T, String) = MacroDefinition.StringifyMacro
14-
@expression macro fileID<T: _ExpressibleByStringLitera>: T = MacroDefinition.FileIDMacro
14+
@expression macro fileID<T: _ExpressibleByStringLiteral>: T = MacroDefinition.FileIDMacro
15+
@expression macro recurse(_: Bool) = MacroDefinition.RecursiveMacro
1516

1617
func testFileID(a: Int, b: Int) {
1718
// CHECK: MacroUser/macro_expand.swift
@@ -73,5 +74,9 @@ func testAddBlocker(a: Int, b: Int, c: Int, oa: OnlyAdds) {
7374
_ = #addBlocker(oa + oa) // expected-error{{blocked an add; did you mean to subtract? (from macro 'addBlocker')}}
7475
// expected-note@-1{{in expansion of macro 'addBlocker' here}}
7576
// expected-note@-2{{use '-'}}{{22-23=-}}
77+
78+
// Check recursion.
79+
#recurse(false) // okay
80+
#recurse(true) // expected-note{{in expansion of macro 'recurse' here}}
7681
#endif
7782
}

0 commit comments

Comments
 (0)