Skip to content

Commit cdcabd0

Browse files
committed
Allow attached macros to be applied to imported C declarations
The Clang importer maps arbitrary attributes spelled with `swift_attr("...")` over to Swift attributes, using the Swift parser to process those attributes. Extend this mechanism to allow `swift_attr` to refer to an attached macro, expanding that macro as needed. When a macro is applied to an imported declaration, that declaration is pretty-printed (from the C++ AST) to provide to the macro implementation. There are a few games we need to place to resolve the macro, and a few more to lazily perform pretty-printing and adjust source locations to get the right information to the macro, but this demonstrates that we could take this path. As an example, we use this mechanism to add an `async` version of a C function that delivers its result via completion handler, using the `@AddAsync` example macro implementation from the swift-syntax repository.
1 parent 5fa12d3 commit cdcabd0

File tree

7 files changed

+258
-3
lines changed

7 files changed

+258
-3
lines changed

lib/AST/NameLookup.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "swift/Basic/SourceManager.h"
4141
#include "swift/Basic/Statistic.h"
4242
#include "swift/ClangImporter/ClangImporterRequests.h"
43+
#include "swift/ClangImporter/ClangModule.h"
4344
#include "swift/Parse/Lexer.h"
4445
#include "swift/Strings.h"
4546
#include "clang/AST/DeclObjC.h"
@@ -1718,6 +1719,14 @@ SmallVector<MacroDecl *, 1> namelookup::lookupMacros(DeclContext *dc,
17181719
ctx.evaluator, UnqualifiedLookupRequest{moduleLookupDesc}, {});
17191720
auto foundTypeDecl = moduleLookup.getSingleTypeResult();
17201721
auto *moduleDecl = dyn_cast_or_null<ModuleDecl>(foundTypeDecl);
1722+
1723+
// When resolving macro names for imported entities, we look for any
1724+
// loaded module.
1725+
if (!moduleDecl && isa<ClangModuleUnit>(moduleScopeDC)) {
1726+
moduleDecl = ctx.getLoadedModule(moduleName.getBaseIdentifier());
1727+
moduleScopeDC = moduleDecl;
1728+
}
1729+
17211730
if (!moduleDecl)
17221731
return {};
17231732

lib/ClangImporter/ClangImporter.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4717,6 +4717,17 @@ bool ClangImporter::Implementation::lookupValue(SwiftLookupTable &table,
47174717
}
47184718
}
47194719

4720+
// Visit auxiliary declarations to check for name matches.
4721+
decl->visitAuxiliaryDecls([&](Decl *aux) {
4722+
if (auto auxValue = dyn_cast<ValueDecl>(aux)) {
4723+
if (auxValue->getName().matchesRef(name) &&
4724+
auxValue->getDeclContext()->isModuleScopeContext()) {
4725+
consumer.foundDecl(auxValue, DeclVisibilityKind::VisibleAtTopLevel);
4726+
anyMatching = true;
4727+
}
4728+
}
4729+
});
4730+
47204731
// If we have a declaration and nothing matched so far, try the names used
47214732
// in other versions of Swift.
47224733
if (auto clangDecl = entry.dyn_cast<clang::NamedDecl *>()) {

lib/Sema/TypeCheckMacros.cpp

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "swift/AST/ASTContext.h"
2424
#include "swift/AST/ASTMangler.h"
2525
#include "swift/AST/ASTNode.h"
26+
#include "swift/AST/ASTPrinter.h"
2627
#include "swift/AST/DiagnosticsFrontend.h"
2728
#include "swift/AST/Expr.h"
2829
#include "swift/AST/FreestandingMacroExpansion.h"
@@ -39,6 +40,7 @@
3940
#include "swift/Basic/Lazy.h"
4041
#include "swift/Basic/SourceManager.h"
4142
#include "swift/Basic/StringExtras.h"
43+
#include "swift/ClangImporter/ClangModule.h"
4244
#include "swift/Bridging/ASTGen.h"
4345
#include "swift/Bridging/Macros.h"
4446
#include "swift/Demangling/Demangler.h"
@@ -1021,7 +1023,10 @@ createMacroSourceFile(std::unique_ptr<llvm::MemoryBuffer> buffer,
10211023
auto macroSourceFile = new (ctx) SourceFile(
10221024
*dc->getParentModule(), SourceFileKind::MacroExpansion, macroBufferID,
10231025
/*parsingOpts=*/{}, /*isPrimary=*/false);
1024-
macroSourceFile->setImports(dc->getParentSourceFile()->getImports());
1026+
if (auto parentSourceFile = dc->getParentSourceFile())
1027+
macroSourceFile->setImports(parentSourceFile->getImports());
1028+
else if (isa<ClangModuleUnit>(dc->getModuleScopeContext()))
1029+
macroSourceFile->setImports({});
10251030
return macroSourceFile;
10261031
}
10271032

@@ -1346,8 +1351,44 @@ static SourceFile *evaluateAttachedMacro(MacroDecl *macro, Decl *attachedTo,
13461351
if (!attrSourceFile)
13471352
return nullptr;
13481353

1349-
auto declSourceFile =
1354+
SourceFile *declSourceFile =
13501355
moduleDecl->getSourceFileContainingLocation(attachedTo->getStartLoc());
1356+
if (!declSourceFile && isa<ClangModuleUnit>(dc->getModuleScopeContext())) {
1357+
// Pretty-print the declaration into a buffer so we can macro-expand
1358+
// it.
1359+
// FIXME: Turn this into a request.
1360+
llvm::SmallString<128> buffer;
1361+
{
1362+
llvm::raw_svector_ostream out(buffer);
1363+
StreamPrinter printer(out);
1364+
attachedTo->print(
1365+
printer,
1366+
PrintOptions::printForDiagnostics(
1367+
AccessLevel::Public,
1368+
ctx.TypeCheckerOpts.PrintFullConvention));
1369+
}
1370+
1371+
// Create the buffer.
1372+
SourceManager &sourceMgr = ctx.SourceMgr;
1373+
auto bufferID = sourceMgr.addMemBufferCopy(buffer);
1374+
auto memBufferStartLoc = sourceMgr.getLocForBufferStart(bufferID);
1375+
sourceMgr.setGeneratedSourceInfo(
1376+
bufferID,
1377+
GeneratedSourceInfo{
1378+
GeneratedSourceInfo::PrettyPrinted,
1379+
CharSourceRange(),
1380+
CharSourceRange(memBufferStartLoc, buffer.size()),
1381+
ASTNode(const_cast<Decl *>(attachedTo)).getOpaqueValue(),
1382+
nullptr
1383+
}
1384+
);
1385+
1386+
// Create a source file to go with it.
1387+
declSourceFile = new (ctx)
1388+
SourceFile(*moduleDecl, SourceFileKind::Library, bufferID);
1389+
moduleDecl->addAuxiliaryFile(*declSourceFile);
1390+
}
1391+
13511392
if (!declSourceFile)
13521393
return nullptr;
13531394

@@ -1486,13 +1527,18 @@ static SourceFile *evaluateAttachedMacro(MacroDecl *macro, Decl *attachedTo,
14861527
if (auto var = dyn_cast<VarDecl>(attachedTo))
14871528
searchDecl = var->getParentPatternBinding();
14881529

1530+
auto startLoc = searchDecl->getStartLoc();
1531+
if (startLoc.isInvalid() && isa<ClangModuleUnit>(dc->getModuleScopeContext())) {
1532+
startLoc = ctx.SourceMgr.getLocForBufferStart(*declSourceFile->getBufferID());
1533+
}
1534+
14891535
BridgedStringRef evaluatedSourceOut{nullptr, 0};
14901536
assert(!externalDef.isError());
14911537
swift_Macros_expandAttachedMacro(
14921538
&ctx.Diags, externalDef.get(), discriminator->c_str(),
14931539
extendedType.c_str(), conformanceList.c_str(), getRawMacroRole(role),
14941540
astGenAttrSourceFile, attr->AtLoc.getOpaquePointerValue(),
1495-
astGenDeclSourceFile, searchDecl->getStartLoc().getOpaquePointerValue(),
1541+
astGenDeclSourceFile, startLoc.getOpaquePointerValue(),
14961542
astGenParentDeclSourceFile, parentDeclLoc, &evaluatedSourceOut);
14971543
if (!evaluatedSourceOut.unbridged().data())
14981544
return nullptr;
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
void async_divide(double x, double y, void (* _Nonnull completionHandler)(double x))
2+
__attribute__((swift_attr("@ModuleUser.AddAsync")));

test/Inputs/clang-importer-sdk/usr/include/module.modulemap

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,7 @@ module IncompleteTypes {
153153
header "IncompleteTypes.h"
154154
export *
155155
}
156+
157+
module CompletionHandlerGlobals {
158+
header "completion_handler_globals.h"
159+
}

test/Macros/Inputs/syntax_macro_definitions.swift

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,163 @@ public enum LeftHandOperandFinderMacro: ExpressionMacro {
887887
}
888888
}
889889

890+
extension SyntaxCollection {
891+
mutating func removeLast() {
892+
self.remove(at: self.index(before: self.endIndex))
893+
}
894+
}
895+
896+
public struct AddAsyncMacro: PeerMacro {
897+
public static func expansion<
898+
Context: MacroExpansionContext,
899+
Declaration: DeclSyntaxProtocol
900+
>(
901+
of node: AttributeSyntax,
902+
providingPeersOf declaration: Declaration,
903+
in context: Context
904+
) throws -> [DeclSyntax] {
905+
906+
// Only on functions at the moment.
907+
guard var funcDecl = declaration.as(FunctionDeclSyntax.self) else {
908+
throw CustomError.message("@addAsync only works on functions")
909+
}
910+
911+
// This only makes sense for non async functions.
912+
if funcDecl.signature.effectSpecifiers?.asyncSpecifier != nil {
913+
throw CustomError.message(
914+
"@addAsync requires an non async function"
915+
)
916+
}
917+
918+
// This only makes sense void functions
919+
if let resultType = funcDecl.signature.returnClause?.type,
920+
resultType.as(IdentifierTypeSyntax.self)?.name.text != "Void" {
921+
throw CustomError.message(
922+
"@addAsync requires an function that returns void"
923+
)
924+
}
925+
926+
// Requires a completion handler block as last parameter
927+
let completionHandlerParameter = funcDecl
928+
.signature
929+
.parameterClause
930+
.parameters.last?
931+
.type.as(AttributedTypeSyntax.self)?
932+
.baseType.as(FunctionTypeSyntax.self)
933+
guard let completionHandlerParameter else {
934+
throw CustomError.message(
935+
"@AddAsync requires an function that has a completion handler as last parameter"
936+
)
937+
}
938+
939+
// Completion handler needs to return Void
940+
if completionHandlerParameter.returnClause.type.as(IdentifierTypeSyntax.self)?.name.text != "Void" {
941+
throw CustomError.message(
942+
"@AddAsync requires an function that has a completion handler that returns Void"
943+
)
944+
}
945+
946+
let returnType = completionHandlerParameter.parameters.first?.type
947+
948+
let isResultReturn = returnType?.children(viewMode: .all).first?.description == "Result"
949+
let successReturnType =
950+
if isResultReturn {
951+
returnType!.as(IdentifierTypeSyntax.self)!.genericArgumentClause?.arguments.first!.argument
952+
} else {
953+
returnType
954+
}
955+
956+
// Remove completionHandler and comma from the previous parameter
957+
var newParameterList = funcDecl.signature.parameterClause.parameters
958+
newParameterList.removeLast()
959+
var newParameterListLastParameter = newParameterList.last!
960+
newParameterList.removeLast()
961+
newParameterListLastParameter.trailingTrivia = []
962+
newParameterListLastParameter.trailingComma = nil
963+
newParameterList.append(newParameterListLastParameter)
964+
965+
// Drop the @AddAsync attribute from the new declaration.
966+
let newAttributeList = funcDecl.attributes.filter {
967+
guard case let .attribute(attribute) = $0,
968+
let attributeType = attribute.attributeName.as(IdentifierTypeSyntax.self),
969+
let nodeType = node.attributeName.as(IdentifierTypeSyntax.self)
970+
else {
971+
return true
972+
}
973+
974+
return attributeType.name.text != nodeType.name.text
975+
}
976+
977+
let callArguments: [String] = newParameterList.map { param in
978+
let argName = param.secondName ?? param.firstName
979+
980+
let paramName = param.firstName
981+
if paramName.text != "_" {
982+
return "\(paramName.text): \(argName.text)"
983+
}
984+
985+
return "\(argName.text)"
986+
}
987+
988+
let switchBody: ExprSyntax =
989+
"""
990+
switch returnValue {
991+
case .success(let value):
992+
continuation.resume(returning: value)
993+
case .failure(let error):
994+
continuation.resume(throwing: error)
995+
}
996+
"""
997+
998+
let newBody: ExprSyntax =
999+
"""
1000+
1001+
\(raw: isResultReturn ? "try await withCheckedThrowingContinuation { continuation in" : "await withCheckedContinuation { continuation in")
1002+
\(raw: funcDecl.name)(\(raw: callArguments.joined(separator: ", "))) { \(raw: returnType != nil ? "returnValue in" : "")
1003+
1004+
\(raw: isResultReturn ? switchBody : "continuation.resume(returning: \(raw: returnType != nil ? "returnValue" : "()"))")
1005+
}
1006+
}
1007+
1008+
"""
1009+
1010+
// add async
1011+
funcDecl.signature.effectSpecifiers = FunctionEffectSpecifiersSyntax(
1012+
leadingTrivia: .space,
1013+
asyncSpecifier: .keyword(.async),
1014+
throwsClause: isResultReturn ? ThrowsClauseSyntax(throwsSpecifier: .keyword(.throws)) : nil
1015+
)
1016+
1017+
// add result type
1018+
if let successReturnType {
1019+
funcDecl.signature.returnClause = ReturnClauseSyntax(
1020+
leadingTrivia: .space,
1021+
type: successReturnType.with(\.leadingTrivia, .space)
1022+
)
1023+
} else {
1024+
funcDecl.signature.returnClause = nil
1025+
}
1026+
1027+
// drop completion handler
1028+
funcDecl.signature.parameterClause.parameters = newParameterList
1029+
funcDecl.signature.parameterClause.trailingTrivia = []
1030+
1031+
funcDecl.body = CodeBlockSyntax(
1032+
leftBrace: .leftBraceToken(leadingTrivia: .space),
1033+
statements: CodeBlockItemListSyntax(
1034+
[CodeBlockItemSyntax(item: .expr(newBody))]
1035+
),
1036+
rightBrace: .rightBraceToken(leadingTrivia: .newline)
1037+
)
1038+
1039+
funcDecl.attributes = newAttributeList
1040+
1041+
funcDecl.leadingTrivia = .newlines(2)
1042+
1043+
return [DeclSyntax(funcDecl)]
1044+
}
1045+
}
1046+
8901047
public struct AddCompletionHandler: PeerMacro {
8911048
public static func expansion(
8921049
of node: AttributeSyntax,

test/Macros/expand_on_imported.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// REQUIRES: swift_swift_parser, executable_test
2+
3+
// RUN: %empty-directory(%t)
4+
// RUN: %host-build-swift -swift-version 5 -emit-library -o %t/%target-library-name(MacroDefinition) -module-name=MacroDefinition %S/Inputs/syntax_macro_definitions.swift -g -no-toolchain-stdlib-rpath -swift-version 5
5+
6+
// Diagnostics testing
7+
// RUN: %target-swift-frontend(mock-sdk: %clang-importer-sdk) -typecheck -verify -swift-version 5 -enable-experimental-feature CodeItemMacros -load-plugin-library %t/%target-library-name(MacroDefinition) -module-name ModuleUser %s
8+
9+
@attached(peer, names: overloaded)
10+
public macro AddAsync() = #externalMacro(module: "MacroDefinition", type: "AddAsyncMacro")
11+
12+
import CompletionHandlerGlobals
13+
14+
// Make sure that @AddAsync works at all.
15+
@AddAsync
16+
@available(SwiftStdlib 5.1, *)
17+
func asyncTest(_ value: Int, completionHandler: @escaping (String) -> Void) {
18+
completionHandler(String(value))
19+
}
20+
21+
@available(SwiftStdlib 5.1, *)
22+
func testAll(x: Double, y: Double) async {
23+
_ = await asyncTest(17)
24+
25+
let _: Double = await async_divide(1.0, 2.0)
26+
}

0 commit comments

Comments
 (0)