Skip to content

Commit 005743c

Browse files
committed
[Distributed] implement adhoc requirements properly for Encoder
1 parent 0fec1b3 commit 005743c

11 files changed

+592
-143
lines changed

include/swift/AST/ASTContext.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -693,15 +693,15 @@ class ASTContext final {
693693
// \param nominal optionally provide a 'NominalTypeDecl' from which the
694694
// function decl shall be extracted. This is useful to avoid witness calls
695695
// through the protocol which is looked up when nominal is null.
696-
FuncDecl *getRecordArgumentOnDistributedInvocationEncoder(
696+
AbstractFunctionDecl *getRecordArgumentOnDistributedInvocationEncoder(
697697
NominalTypeDecl *nominal) const;
698698

699-
// Retrieve the declaration of DistributedInvocationEncoder.recordErrorType(_:).
700-
FuncDecl *getRecordErrorTypeOnDistributedInvocationEncoder(
699+
// Retrieve the declaration of DistributedInvocationEncoder.recordReturnType(_:).
700+
AbstractFunctionDecl *getRecordReturnTypeOnDistributedInvocationEncoder(
701701
NominalTypeDecl *nominal) const;
702702

703-
// Retrieve the declaration of DistributedInvocationEncoder.recordReturnType(_:).
704-
FuncDecl *getRecordReturnTypeOnDistributedInvocationEncoder(
703+
// Retrieve the declaration of DistributedInvocationEncoder.recordErrorType(_:).
704+
AbstractFunctionDecl *getRecordErrorTypeOnDistributedInvocationEncoder(
705705
NominalTypeDecl *nominal) const;
706706

707707
// Retrieve the declaration of DistributedInvocationEncoder.doneRecording().
@@ -1351,14 +1351,15 @@ class ASTContext final {
13511351
/// alternative specified via the -entry-point-function-name frontend flag.
13521352
std::string getEntryPointFunctionName() const;
13531353

1354-
Type getAssociatedTypeOfDistributedSystem(NominalTypeDecl *actor,
1354+
Type getAssociatedTypeOfDistributedSystemOfActor(NominalTypeDecl *actor,
13551355
Identifier member);
13561356

1357-
/// Find the type of SerializationRequirement on the passed nominal.
1358-
///
1359-
/// This type exists as a typealias/associatedtype on all distributed actors,
1360-
/// actor systems, and related serialization types.
1361-
Type getDistributedSerializationRequirementType(NominalTypeDecl *);
1357+
// /// Find the type of SerializationRequirement on the passed nominal.
1358+
// ///
1359+
// /// This type exists as a typealias/associatedtype on all distributed actors,
1360+
// /// actor systems, and related serialization types.
1361+
// Type getDistributedSerializationRequirementType(
1362+
// NominalTypeDecl *, ProtocolDecl *protocolDecl);
13621363

13631364
/// Find the concrete invocation decoder associated with the given actor.
13641365
NominalTypeDecl *

include/swift/AST/Decl.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6422,6 +6422,21 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
64226422
/// 'DistributedActorSystem' protocol.
64236423
bool isDistributedActorSystemRemoteCall(bool isVoidReturn) const;
64246424

6425+
/// Determines if this function is a 'recordArgument' function,
6426+
/// which is used as ad-hoc protocol requirement by the
6427+
/// 'DistributedTargetInvocationEncoder' protocol.
6428+
bool isDistributedTargetInvocationEncoderRecordArgument() const;
6429+
6430+
/// Determines if this function is a 'recordReturnType' function,
6431+
/// which is used as ad-hoc protocol requirement by the
6432+
/// 'DistributedTargetInvocationEncoder' protocol.
6433+
bool isDistributedTargetInvocationEncoderRecordReturnType() const;
6434+
6435+
/// Determines if this function is a 'recordErrorType' function,
6436+
/// which is used as ad-hoc protocol requirement by the
6437+
/// 'DistributedTargetInvocationEncoder' protocol.
6438+
bool isDistributedTargetInvocationEncoderRecordErrorType() const;
6439+
64256440
/// For a method of a class, checks whether it will require a new entry in the
64266441
/// vtable.
64276442
bool needsNewVTableEntry() const;

include/swift/AST/DistributedDecl.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,20 @@ Type getDistributedActorSystemType(NominalTypeDecl *actor);
3737
/// Determine the `ID` type for the given actor.
3838
Type getDistributedActorIDType(NominalTypeDecl *actor);
3939

40-
Type getDistributedActorSystemSerializationRequirementType(
41-
NominalTypeDecl *system);
40+
/// Get specific 'SerializationRequirement' as defined in 'nominal'
41+
/// type, which must conform to the passed 'protocol' which is expected
42+
/// to require the 'SerializationRequirement'.
43+
Type getDistributedSerializationRequirementType(
44+
NominalTypeDecl *nominal, ProtocolDecl *protocol);
45+
46+
///// Determine the serialization requirement for the given actor, actor system
47+
///// or other type that has the SerializationRequirement associated type.
48+
//Type getDistributedSerializationRequirementType(
49+
// NominalTypeDecl *nominal, ProtocolDecl *protocol);
4250

4351
Type getDistributedActorSystemActorIDRequirementType(
4452
NominalTypeDecl *system);
4553

46-
/// Determine the serialization requirement for the given actor, actor system
47-
/// or other type that has the SerializationRequirement associated type.
48-
Type getDistributedSerializationRequirementType(NominalTypeDecl *actor);
4954

5055
/// Get the specific protocols that the `SerializationRequirement` specifies,
5156
/// and all parameters / return types of distributed targets must conform to.
@@ -55,7 +60,8 @@ Type getDistributedSerializationRequirementType(NominalTypeDecl *actor);
5560
///
5661
/// Returns an empty set if the requirement was `Any`.
5762
llvm::SmallPtrSet<ProtocolDecl *, 2>
58-
getDistributedSerializationRequirementProtocols(NominalTypeDecl *decl);
63+
getDistributedSerializationRequirementProtocols(
64+
NominalTypeDecl *decl, ProtocolDecl* protocol);
5965

6066
/// Desugar and flatten the `SerializationRequirement` type into a set of
6167
/// specific protocol declarations.
@@ -78,6 +84,7 @@ bool checkDistributedSerializationRequirementIsExactlyCodable(
7884
bool
7985
getDistributedActorSystemSerializationRequirements(
8086
NominalTypeDecl *systemDecl,
87+
ProtocolDecl *protocol,
8188
llvm::SmallPtrSetImpl<ProtocolDecl *> &requirementProtos);
8289

8390
/// Given any set of generic requirements, locate those which are about the

include/swift/AST/TypeCheckRequests.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,60 @@ class GetDistributedActorSystemRemoteCallFunctionRequest :
10551055
bool isCached() const { return true; }
10561056
};
10571057

1058+
/// Obtain the 'recordArgument' function of a 'DistributedTargetInvocationEncoder'.
1059+
class GetDistributedTargetInvocationEncoderRecordArgumentFunctionRequest :
1060+
public SimpleRequest<GetDistributedTargetInvocationEncoderRecordArgumentFunctionRequest,
1061+
AbstractFunctionDecl *(NominalTypeDecl *),
1062+
RequestFlags::Cached> {
1063+
public:
1064+
using SimpleRequest::SimpleRequest;
1065+
1066+
private:
1067+
friend SimpleRequest;
1068+
1069+
AbstractFunctionDecl *evaluate(Evaluator &evaluator, NominalTypeDecl *encoder) const;
1070+
1071+
public:
1072+
// Caching
1073+
bool isCached() const { return true; }
1074+
};
1075+
1076+
/// Obtain the 'recordReturnType' function of a 'DistributedTargetInvocationEncoder'.
1077+
class GetDistributedTargetInvocationEncoderRecordReturnTypeFunctionRequest :
1078+
public SimpleRequest<GetDistributedTargetInvocationEncoderRecordReturnTypeFunctionRequest,
1079+
AbstractFunctionDecl *(NominalTypeDecl *),
1080+
RequestFlags::Cached> {
1081+
public:
1082+
using SimpleRequest::SimpleRequest;
1083+
1084+
private:
1085+
friend SimpleRequest;
1086+
1087+
AbstractFunctionDecl *evaluate(Evaluator &evaluator, NominalTypeDecl *encoder) const;
1088+
1089+
public:
1090+
// Caching
1091+
bool isCached() const { return true; }
1092+
};
1093+
1094+
/// Obtain the 'recordErrorType' function of a 'DistributedTargetInvocationEncoder'.
1095+
class GetDistributedTargetInvocationEncoderRecordErrorTypeFunctionRequest :
1096+
public SimpleRequest<GetDistributedTargetInvocationEncoderRecordErrorTypeFunctionRequest,
1097+
AbstractFunctionDecl *(NominalTypeDecl *),
1098+
RequestFlags::Cached> {
1099+
public:
1100+
using SimpleRequest::SimpleRequest;
1101+
1102+
private:
1103+
friend SimpleRequest;
1104+
1105+
AbstractFunctionDecl *evaluate(Evaluator &evaluator, NominalTypeDecl *encoder) const;
1106+
1107+
public:
1108+
// Caching
1109+
bool isCached() const { return true; }
1110+
};
1111+
10581112
/// Obtain the 'actorSystem' property of a 'distributed actor'.
10591113
class GetDistributedActorSystemPropertyRequest :
10601114
public SimpleRequest<GetDistributedActorSystemPropertyRequest,

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ SWIFT_REQUEST(TypeChecker, IsDistributedActorRequest, bool(NominalTypeDecl *),
108108
SWIFT_REQUEST(TypeChecker, GetDistributedActorSystemRemoteCallFunctionRequest,
109109
AbstractFunctionDecl *(NominalTypeDecl *, bool),
110110
Cached, NoLocationInfo)
111+
SWIFT_REQUEST(TypeChecker, GetDistributedTargetInvocationEncoderRecordArgumentFunctionRequest,
112+
AbstractFunctionDecl *(NominalTypeDecl *),
113+
Cached, NoLocationInfo)
114+
SWIFT_REQUEST(TypeChecker, GetDistributedTargetInvocationEncoderRecordReturnTypeFunctionRequest,
115+
AbstractFunctionDecl *(NominalTypeDecl *),
116+
Cached, NoLocationInfo)
117+
SWIFT_REQUEST(TypeChecker, GetDistributedTargetInvocationEncoderRecordErrorTypeFunctionRequest,
118+
AbstractFunctionDecl *(NominalTypeDecl *),
119+
Cached, NoLocationInfo)
111120
SWIFT_REQUEST(TypeChecker, GetDistributedActorIDPropertyRequest,
112121
VarDecl *(NominalTypeDecl *),
113122
Cached, NoLocationInfo)

lib/AST/ASTContext.cpp

Lines changed: 15 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,67 +1358,28 @@ ASTContext::getRecordGenericSubstitutionOnDistributedInvocationEncoder(
13581358
return nullptr;
13591359
}
13601360

1361-
FuncDecl *ASTContext::getRecordArgumentOnDistributedInvocationEncoder(
1361+
AbstractFunctionDecl *ASTContext::getRecordArgumentOnDistributedInvocationEncoder(
13621362
NominalTypeDecl *nominal) const {
1363-
for (auto result : nominal->lookupDirect(Id_recordArgument)) {
1364-
auto *fd = dyn_cast<FuncDecl>(result);
1365-
if (!fd)
1366-
continue;
1367-
if (fd->getParameters()->size() != 1)
1368-
continue;
1369-
if (fd->hasAsync())
1370-
continue;
1371-
if (!fd->hasThrows())
1372-
continue;
1373-
// TODO(distributed): more checks
1374-
1375-
if (fd->getResultInterfaceType()->isVoid())
1376-
return fd;
1377-
}
1378-
1379-
return nullptr;
1363+
return evaluateOrDefault(
1364+
nominal->getASTContext().evaluator,
1365+
GetDistributedTargetInvocationEncoderRecordArgumentFunctionRequest{nominal},
1366+
nullptr);
13801367
}
13811368

1382-
FuncDecl *ASTContext::getRecordErrorTypeOnDistributedInvocationEncoder(
1369+
AbstractFunctionDecl *ASTContext::getRecordReturnTypeOnDistributedInvocationEncoder(
13831370
NominalTypeDecl *nominal) const {
1384-
for (auto result : nominal->lookupDirect(Id_recordErrorType)) {
1385-
auto *fd = dyn_cast<FuncDecl>(result);
1386-
if (!fd)
1387-
continue;
1388-
if (fd->getParameters()->size() != 1)
1389-
continue;
1390-
if (fd->hasAsync())
1391-
continue;
1392-
if (!fd->hasThrows())
1393-
continue;
1394-
// TODO(distributed): more checks
1395-
1396-
if (fd->getResultInterfaceType()->isVoid())
1397-
return fd;
1398-
}
1399-
1400-
return nullptr;
1371+
return evaluateOrDefault(
1372+
nominal->getASTContext().evaluator,
1373+
GetDistributedTargetInvocationEncoderRecordReturnTypeFunctionRequest{nominal},
1374+
nullptr);
14011375
}
14021376

1403-
FuncDecl *ASTContext::getRecordReturnTypeOnDistributedInvocationEncoder(
1377+
AbstractFunctionDecl *ASTContext::getRecordErrorTypeOnDistributedInvocationEncoder(
14041378
NominalTypeDecl *nominal) const {
1405-
for (auto result : nominal->lookupDirect(Id_recordReturnType)) {
1406-
auto *fd = dyn_cast<FuncDecl>(result);
1407-
if (!fd)
1408-
continue;
1409-
if (fd->getParameters()->size() != 1)
1410-
continue;
1411-
if (fd->hasAsync())
1412-
continue;
1413-
if (!fd->hasThrows())
1414-
continue;
1415-
// TODO(distributed): more checks
1416-
1417-
if (fd->getResultInterfaceType()->isVoid())
1418-
return fd;
1419-
}
1420-
1421-
return nullptr;
1379+
return evaluateOrDefault(
1380+
nominal->getASTContext().evaluator,
1381+
GetDistributedTargetInvocationEncoderRecordErrorTypeFunctionRequest{nominal},
1382+
nullptr);
14221383
}
14231384

14241385
FuncDecl *ASTContext::getDoneRecordingOnDistributedInvocationEncoder(

0 commit comments

Comments
 (0)