Skip to content

Commit 19d1588

Browse files
committed
[Macros] Handle macro overloading.
Allow more than one macro plugin to introduce a macro with the same name, and let the constraint solver figure out which one to call. Also eliminates a potential use-after-free if we somehow find additional compiler plugins to load after having expanded a macro.
1 parent bdf7762 commit 19d1588

File tree

6 files changed

+126
-19
lines changed

6 files changed

+126
-19
lines changed

include/swift/AST/ASTContext.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,6 @@ class ASTContext final {
352352
llvm::SmallPtrSet<DerivativeAttr *, 1>>
353353
DerivativeAttrs;
354354

355-
/// Cache of compiler plugins keyed by their name.
356-
llvm::StringMap<CompilerPlugin> LoadedPlugins;
357-
358-
/// Cache of loaded symbols.
359-
llvm::StringMap<void *> LoadedSymbols;
360-
361355
private:
362356
/// The current generation number, which reflects the number of
363357
/// times that external modules have been loaded.
@@ -1452,8 +1446,11 @@ class ASTContext final {
14521446
/// The declared interface type of Builtin.TheTupleType.
14531447
BuiltinTupleType *getBuiltinTupleType();
14541448

1455-
/// Finds the loaded compiler plugin given its name.
1456-
CompilerPlugin *getLoadedPlugin(StringRef name);
1449+
/// Finds the loaded compiler plugins with the given name.
1450+
TinyPtrVector<CompilerPlugin *> getLoadedPlugins(StringRef name);
1451+
1452+
/// Add a loaded plugin with the given name.
1453+
void addLoadedPlugin(StringRef name, CompilerPlugin *plugin);
14571454

14581455
/// Finds the address of the given symbol. If `libraryHandleHint` is non-null,
14591456
/// search within the library.

lib/AST/ASTContext.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,15 @@ struct ASTContext::Implementation {
515515

516516
llvm::StringMap<OptionSet<SearchPathKind>> SearchPathsSet;
517517

518+
/// Cache of compiler plugins keyed by their name.
519+
///
520+
/// Names can be overloaded, so there can be multiple plugins with the same
521+
/// name.
522+
llvm::StringMap<TinyPtrVector<CompilerPlugin*>> LoadedPlugins;
523+
524+
/// Cache of loaded symbols.
525+
llvm::StringMap<void *> LoadedSymbols;
526+
518527
/// The permanent arena.
519528
Arena Permanent;
520529

@@ -579,6 +588,12 @@ ASTContext::Implementation::Implementation()
579588
ASTContext::Implementation::~Implementation() {
580589
for (auto &cleanup : Cleanups)
581590
cleanup();
591+
592+
for (const auto &pluginsByName : LoadedPlugins) {
593+
for (auto plugin : pluginsByName.second) {
594+
delete plugin;
595+
}
596+
}
582597
}
583598

584599
ConstraintCheckerArenaRAII::
@@ -6047,16 +6062,21 @@ BuiltinTupleType *ASTContext::getBuiltinTupleType() {
60476062
return result;
60486063
}
60496064

6050-
CompilerPlugin *ASTContext::getLoadedPlugin(StringRef name) {
6051-
auto lookup = LoadedPlugins.find(name);
6052-
if (lookup == LoadedPlugins.end())
6053-
return nullptr;
6054-
return &lookup->second;
6065+
TinyPtrVector<CompilerPlugin *> ASTContext::getLoadedPlugins(StringRef name) {
6066+
auto &loadedPlugins = getImpl().LoadedPlugins;
6067+
auto lookup = loadedPlugins.find(name);
6068+
if (lookup == loadedPlugins.end())
6069+
return { };
6070+
return lookup->second;
6071+
}
6072+
6073+
void ASTContext::addLoadedPlugin(StringRef name, CompilerPlugin *plugin) {
6074+
getImpl().LoadedPlugins[name].push_back(plugin);
60556075
}
60566076

60576077
void *ASTContext::getAddressOfSymbol(const char *name,
60586078
void *libraryHandleHint) {
6059-
auto lookup = LoadedSymbols.try_emplace(name, nullptr);
6079+
auto lookup = getImpl().LoadedSymbols.try_emplace(name, nullptr);
60606080
void *&address = lookup.first->getValue();
60616081
#if !defined(_WIN32)
60626082
if (lookup.second) {

lib/AST/CompilerPlugin.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,9 @@ void ASTContext::loadCompilerPlugins() {
188188
swift_ASTGen_getMacroTypes(getter, &metatypesAddress, &metatypeCount);
189189
ArrayRef<const void *> metatypes(metatypesAddress, metatypeCount);
190190
for (const void *metatype : metatypes) {
191-
CompilerPlugin plugin(metatype, lib, *this);
192-
auto name = plugin.getName();
193-
LoadedPlugins.try_emplace(name, std::move(plugin));
191+
auto plugin = new CompilerPlugin(metatype, lib, *this);
192+
auto name = plugin->getName();
193+
addLoadedPlugin(name, plugin);
194194
}
195195
free(const_cast<void *>((const void *)metatypes.data()));
196196
#endif // SWIFT_SWIFT_PARSER

lib/Sema/TypeCheckMacros.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ ArrayRef<MacroDecl *> MacroLookupRequest::evaluate(
317317
// Look for a loaded plugin based on the macro name.
318318
// FIXME: This API needs to be able to return multiple plugins, because
319319
// several plugins could export a macro with the same name.
320-
if (auto *plugin = ctx.getLoadedPlugin(macroName.str())) {
320+
for (auto plugin: ctx.getLoadedPlugins(macroName.str())) {
321321
if (auto pluginMacro = createPluginMacro(mod, macroName, plugin)) {
322322
macros.push_back(pluginMacro);
323323
}

test/Macros/Inputs/macro_definition.swift

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,95 @@ struct ColorLiteralMacro: _CompilerPlugin {
180180
}
181181
}
182182

183+
struct HSVColorLiteralMacro: _CompilerPlugin {
184+
static func _name() -> (UnsafePointer<UInt8>, count: Int) {
185+
var name = "customColorLiteral"
186+
return name.withUTF8 { buffer in
187+
let result = UnsafeMutablePointer<UInt8>.allocate(capacity: buffer.count)
188+
result.initialize(from: buffer.baseAddress!, count: buffer.count)
189+
return (UnsafePointer(result), count: buffer.count)
190+
}
191+
}
192+
193+
static func _genericSignature() -> (UnsafePointer<UInt8>?, count: Int) {
194+
var genSig = "<T>"
195+
return genSig.withUTF8 { buffer in
196+
let result = UnsafeMutablePointer<UInt8>.allocate(capacity: buffer.count)
197+
result.initialize(from: buffer.baseAddress!, count: buffer.count)
198+
return (UnsafePointer(result), count: buffer.count)
199+
}
200+
}
201+
202+
static func _typeSignature() -> (UnsafePointer<UInt8>, count: Int) {
203+
var typeSig =
204+
"""
205+
(
206+
hue hue: Float, saturation saturation: Float, value value: Float
207+
) -> T
208+
"""
209+
return typeSig.withUTF8 { buffer in
210+
let result = UnsafeMutablePointer<UInt8>.allocate(capacity: buffer.count)
211+
result.initialize(from: buffer.baseAddress!, count: buffer.count)
212+
return (UnsafePointer(result), count: buffer.count)
213+
}
214+
}
215+
216+
static func _owningModule() -> (UnsafePointer<UInt8>, count: Int) {
217+
var swiftModule = "Swift"
218+
return swiftModule.withUTF8 { buffer in
219+
let result = UnsafeMutablePointer<UInt8>.allocate(capacity: buffer.count)
220+
result.initialize(from: buffer.baseAddress!, count: buffer.count)
221+
return (UnsafePointer(result), count: buffer.count)
222+
}
223+
}
224+
225+
static func _supplementalSignatureModules() -> (UnsafePointer<UInt8>, count: Int) {
226+
var nothing = ""
227+
return nothing.withUTF8 { buffer in
228+
let result = UnsafeMutablePointer<UInt8>.allocate(capacity: buffer.count)
229+
result.initialize(from: buffer.baseAddress!, count: buffer.count)
230+
return (UnsafePointer(result), count: buffer.count)
231+
}
232+
}
233+
234+
static func _kind() -> _CompilerPluginKind {
235+
.expressionMacro
236+
}
237+
238+
static func _rewrite(
239+
targetModuleName: UnsafePointer<UInt8>,
240+
targetModuleNameCount: Int,
241+
filePath: UnsafePointer<UInt8>,
242+
filePathCount: Int,
243+
sourceFileText: UnsafePointer<UInt8>,
244+
sourceFileTextCount: Int,
245+
localSourceText: UnsafePointer<UInt8>,
246+
localSourceTextCount: Int
247+
) -> (UnsafePointer<UInt8>?, count: Int) {
248+
let meeTextBuffer = UnsafeBufferPointer(
249+
start: localSourceText, count: localSourceTextCount)
250+
let meeText = String(decoding: meeTextBuffer, as: UTF8.self)
251+
let prefix = "#customColorLiteral(hue:"
252+
guard meeText.starts(with: prefix), meeText.last == ")" else {
253+
return (nil, 0)
254+
}
255+
let expr = meeText.dropFirst(prefix.count).dropLast()
256+
var resultString = ".init(_colorLiteralHue:\(expr))"
257+
return resultString.withUTF8 { buffer in
258+
let result = UnsafeMutableBufferPointer<UInt8>.allocate(
259+
capacity: buffer.count + 1)
260+
_ = result.initialize(from: buffer)
261+
result[buffer.count] = 0
262+
return (UnsafePointer(result.baseAddress), buffer.count)
263+
}
264+
}
265+
}
266+
183267

184268
public var allMacros: [Any.Type] {
185269
[
186270
StringifyMacro.self,
187-
ColorLiteralMacro.self
271+
ColorLiteralMacro.self,
272+
HSVColorLiteralMacro.self
188273
]
189274
}

test/Macros/macro_plugin.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ let _ = #customStringify(["a", "b", "c"] + ["d", "e", "f"])
3737

3838
struct MyColor: _ExpressibleByColorLiteral {
3939
init(_colorLiteralRed red: Float, green: Float, blue: Float, alpha: Float) { }
40+
init(_colorLiteralHue hue: Float, saturation: Float, value: Float) { }
4041
}
4142

43+
// CHECK: (macro_expansion_expr type='MyColor' {{.*}} name=customColorLiteral
4244
let _: MyColor = #customColorLiteral(red: 0.5, green: 0.5, blue: 0.2, alpha: 0.9)
4345

46+
// CHECK: (macro_expansion_expr type='MyColor' {{.*}} name=customColorLiteral
47+
let _: MyColor = #customColorLiteral(hue: 0.5, saturation: 0.5, value: 0.2)
48+

0 commit comments

Comments
 (0)