Skip to content

Commit 436ecb2

Browse files
committed
Use more of getConcreteReplacementForMemberSerializationRequirement
1 parent f91b12b commit 436ecb2

File tree

3 files changed

+94
-114
lines changed

3 files changed

+94
-114
lines changed

include/swift/AST/DistributedDecl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ getDistributedSerializationRequirementProtocols(
9797
/// If so, we can emit slightly nicer diagnostics.
9898
bool checkDistributedSerializationRequirementIsExactlyCodable(
9999
ASTContext &C,
100-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements);
100+
Type type);
101101

102102
/// Get the `SerializationRequirement`, explode it into the specific
103103
/// protocol requirements and insert them into `requirements`.

lib/AST/DistributedDecl.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,17 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement(
106106
return getDistributedSerializationRequirementType(classDecl, C.getDistributedActorDecl());
107107
}
108108

109-
/// === Maybe the value is declared in a protocol?
110-
if (auto protocol = DC->getSelfProtocolDecl()) {
109+
auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement)
110+
->getDeclaredInterfaceType();
111+
112+
if (DC->getSelfProtocolDecl() || isa<ExtensionDecl>(DC)) {
111113
GenericSignature signature;
112114
if (auto *genericContext = member->getAsGenericContext()) {
113115
signature = genericContext->getGenericSignature();
114116
} else {
115117
signature = DC->getGenericSignatureOfContext();
116118
}
117119

118-
auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement)
119-
->getDeclaredInterfaceType();
120-
121120
// Note that this may be null, e.g. if we're a distributed func inside
122121
// a protocol that did not declare a specific actor system requirement.
123122
return signature->getConcreteType(SerReqAssocType);
@@ -355,15 +354,24 @@ swift::getDistributedSerializationRequirements(
355354

356355
bool swift::checkDistributedSerializationRequirementIsExactlyCodable(
357356
ASTContext &C,
358-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements) {
357+
Type type) {
358+
if (!type)
359+
return false;
360+
361+
if (type->hasError())
362+
return false;
363+
359364
auto encodable = C.getProtocol(KnownProtocolKind::Encodable);
360365
auto decodable = C.getProtocol(KnownProtocolKind::Decodable);
361366

362-
if (allRequirements.size() != 2)
367+
auto layout = type->getExistentialLayout();
368+
auto protocols = layout.getProtocols();
369+
370+
if (protocols.size() != 2)
363371
return false;
364372

365-
return allRequirements.count(encodable) &&
366-
allRequirements.count(decodable);
373+
return std::count(protocols.begin(), protocols.end(), encodable) == 1 &&
374+
std::count(protocols.begin(), protocols.end(), decodable) == 1;
367375
}
368376

369377
/******************************************************************************/

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 76 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -376,10 +376,13 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
376376

377377
static bool checkDistributedTargetResultType(
378378
ModuleDecl *module, ValueDecl *valueDecl,
379-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &serializationRequirements,
379+
Type serializationRequirement,
380380
bool diagnose) {
381381
auto &C = valueDecl->getASTContext();
382382

383+
if (!serializationRequirement || serializationRequirement->hasError())
384+
return false; // error of the type would be diagnosed elsewhere
385+
383386
Type resultType;
384387
if (auto func = dyn_cast<FuncDecl>(valueDecl)) {
385388
resultType = func->mapTypeIntoContext(func->getResultInterfaceType());
@@ -394,36 +397,39 @@ static bool checkDistributedTargetResultType(
394397

395398
auto isCodableRequirement =
396399
checkDistributedSerializationRequirementIsExactlyCodable(
397-
C, serializationRequirements);
398-
399-
for(auto serializationReq : serializationRequirements) {
400-
auto conformance =
401-
TypeChecker::conformsToProtocol(resultType, serializationReq, module);
402-
if (conformance.isInvalid()) {
403-
if (diagnose) {
404-
llvm::StringRef conformanceToSuggest = isCodableRequirement ?
405-
"Codable" : // Codable is a typealias, easier to diagnose like that
406-
serializationReq->getNameStr();
407-
408-
auto diag = valueDecl->diagnose(
409-
diag::distributed_actor_target_result_not_codable,
410-
resultType,
411-
valueDecl,
412-
conformanceToSuggest
413-
);
414-
415-
if (isCodableRequirement) {
416-
if (auto resultNominalType = resultType->getAnyNominal()) {
417-
addCodableFixIt(resultNominalType, diag);
400+
C, serializationRequirement);
401+
402+
if (serializationRequirement && !serializationRequirement->hasError()) {
403+
auto srl = serializationRequirement->getExistentialLayout();
404+
for (auto serializationReq: srl.getProtocols()) {
405+
auto conformance =
406+
TypeChecker::conformsToProtocol(resultType, serializationReq, module);
407+
if (conformance.isInvalid()) {
408+
if (diagnose) {
409+
llvm::StringRef conformanceToSuggest = isCodableRequirement ?
410+
"Codable" : // Codable is a typealias, easier to diagnose like that
411+
serializationReq->getNameStr();
412+
413+
auto diag = valueDecl->diagnose(
414+
diag::distributed_actor_target_result_not_codable,
415+
resultType,
416+
valueDecl,
417+
conformanceToSuggest
418+
);
419+
420+
if (isCodableRequirement) {
421+
if (auto resultNominalType = resultType->getAnyNominal()) {
422+
addCodableFixIt(resultNominalType, diag);
423+
}
418424
}
419-
}
420-
} // end if: diagnose
421-
422-
return true;
425+
} // end if: diagnose
426+
427+
return true;
428+
}
423429
}
424430
}
425431

426-
return false;
432+
return false;
427433
}
428434

429435
bool swift::checkDistributedActorSystem(const NominalTypeDecl *system) {
@@ -494,74 +500,35 @@ bool CheckDistributedFunctionRequest::evaluate(
494500
if (!C.getLoadedModule(C.Id_Distributed))
495501
return true;
496502

497-
// === All parameters and the result type must conform
498-
// SerializationRequirement
499-
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationRequirements;
500-
if (auto extension = dyn_cast<ExtensionDecl>(DC)) {
501-
auto actorOrProtocol = extension->getExtendedNominal();
502-
if (auto actor = dyn_cast<ClassDecl>(actorOrProtocol)) {
503-
assert(actor->isAnyActor());
504-
serializationRequirements = getDistributedSerializationRequirementProtocols(
505-
getDistributedActorSystemType(actor)->getAnyNominal(),
506-
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
507-
} else if (auto protocol = dyn_cast<ProtocolDecl>(actorOrProtocol)) {
508-
extractDistributedSerializationRequirements(
509-
C, protocol->getGenericRequirements(),
510-
/*into=*/serializationRequirements);
511-
extractDistributedSerializationRequirements(
512-
C, extension->getGenericRequirements(),
513-
/*into=*/serializationRequirements);
514-
} else {
515-
// ignore
516-
}
517-
} else if (auto actor = dyn_cast<ClassDecl>(DC)) {
518-
serializationRequirements = getDistributedSerializationRequirementProtocols(
519-
getDistributedActorSystemType(actor)->getAnyNominal(),
520-
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
521-
} else if (isa<ProtocolDecl>(DC)) {
522-
if (auto seqReqTy =
523-
getConcreteReplacementForMemberSerializationRequirement(func)) {
524-
auto layout = seqReqTy->getExistentialLayout();
525-
for (auto req : layout.getProtocols()) {
526-
serializationRequirements.insert(req);
527-
}
528-
}
529-
530-
// The distributed actor constrained protocol has no serialization requirements
531-
// or actor system defined, so these will only be enforced, by implementations
532-
// of DAs conforming to it, skip checks here.
533-
if (serializationRequirements.empty()) {
534-
return false;
535-
}
536-
} else {
537-
llvm_unreachable("Distributed function detected in type other than extension, "
538-
"distributed actor, or protocol! This should not be possible "
539-
", please file a bug.");
540-
}
541-
542-
// If the requirement is exactly `Codable` we diagnose it ia bit nicer.
543-
auto serializationRequirementIsCodable =
544-
checkDistributedSerializationRequirementIsExactlyCodable(
545-
C, serializationRequirements);
546-
547-
for (auto param : *func->getParameters()) {
548-
// --- Check parameters for 'Codable' conformance
549-
auto paramTy = func->mapTypeIntoContext(param->getInterfaceType());
550-
551-
for (auto req : serializationRequirements) {
552-
if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) {
553-
auto diag = func->diagnose(
554-
diag::distributed_actor_func_param_not_codable,
555-
param->getArgumentName().str(), param->getInterfaceType(),
556-
func->getDescriptiveKind(),
557-
serializationRequirementIsCodable ? "Codable"
558-
: req->getNameStr());
559-
560-
if (auto paramNominalTy = paramTy->getAnyNominal()) {
561-
addCodableFixIt(paramNominalTy, diag);
562-
} // else, no nominal type to suggest the fixit for, e.g. a closure
563-
564-
return true;
503+
Type serializationReqType = getConcreteReplacementForMemberSerializationRequirement(func);
504+
for (auto param: *func->getParameters()) {
505+
506+
// --- Check the parameter conforming to serialization requirements
507+
if (serializationReqType && !serializationReqType->hasError()) {
508+
// If the requirement is exactly `Codable` we diagnose it ia bit nicer.
509+
auto serializationRequirementIsCodable =
510+
checkDistributedSerializationRequirementIsExactlyCodable(
511+
C, serializationReqType);
512+
513+
// --- Check parameters for 'SerializationRequirement' conformance
514+
auto paramTy = func->mapTypeIntoContext(param->getInterfaceType());
515+
516+
auto srl = serializationReqType->getExistentialLayout();
517+
for (auto req: srl.getProtocols()) {
518+
if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) {
519+
auto diag = func->diagnose(
520+
diag::distributed_actor_func_param_not_codable,
521+
param->getArgumentName().str(), param->getInterfaceType(),
522+
func->getDescriptiveKind(),
523+
serializationRequirementIsCodable ? "Codable"
524+
: req->getNameStr());
525+
526+
if (auto paramNominalTy = paramTy->getAnyNominal()) {
527+
addCodableFixIt(paramNominalTy, diag);
528+
} // else, no nominal type to suggest the fixit for, e.g. a closure
529+
530+
return true;
531+
}
565532
}
566533
}
567534

@@ -598,10 +565,12 @@ bool CheckDistributedFunctionRequest::evaluate(
598565
}
599566
}
600567

601-
// --- Result type must be either void or a codable type
602-
if (checkDistributedTargetResultType(module, func, serializationRequirements,
603-
/*diagnose=*/true)) {
604-
return true;
568+
if (serializationReqType && !serializationReqType->hasError()) {
569+
// --- Result type must be either void or a codable type
570+
if (checkDistributedTargetResultType(module, func, serializationReqType,
571+
/*diagnose=*/true)) {
572+
return true;
573+
}
605574
}
606575

607576
return false;
@@ -649,13 +618,15 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
649618
DC->getSelfNominalTypeDecl()->getDistributedActorSystemProperty();
650619
auto systemDecl = systemVar->getInterfaceType()->getAnyNominal();
651620

652-
auto serializationRequirements =
653-
getDistributedSerializationRequirementProtocols(
654-
systemDecl,
655-
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
621+
// auto serializationRequirements =
622+
// getDistributedSerializationRequirementProtocols(
623+
// systemDecl,
624+
// C.getProtocol(KnownProtocolKind::DistributedActorSystem));
625+
auto serializationRequirement =
626+
getConcreteReplacementForMemberSerializationRequirement(systemVar);
656627

657628
auto module = var->getModuleContext();
658-
if (checkDistributedTargetResultType(module, var, serializationRequirements, diagnose)) {
629+
if (checkDistributedTargetResultType(module, var, serializationRequirement, diagnose)) {
659630
return true;
660631
}
661632

@@ -762,6 +733,7 @@ bool TypeChecker::checkDistributedFunc(FuncDecl *func) {
762733
return swift::checkDistributedFunction(func);
763734
}
764735

736+
// TODO(distributed): Remove this entirely and rely on generic signature and getConcrete to implement checks
765737
llvm::SmallPtrSet<ProtocolDecl *, 2>
766738
swift::getDistributedSerializationRequirementProtocols(
767739
NominalTypeDecl *nominal, ProtocolDecl *protocol) {

0 commit comments

Comments
 (0)