Skip to content

Commit 8ec52c8

Browse files
committed
[Strict memory safety] Nested types are safe/unsafe independent of their enclosing type
When determining whether a nested type is safe, don't consider whether its enclosing type is safe. They're independent.
1 parent fe68567 commit 8ec52c8

File tree

3 files changed

+124
-16
lines changed

3 files changed

+124
-16
lines changed

lib/AST/ASTContext.cpp

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4190,6 +4190,55 @@ void UnboundGenericType::Profile(llvm::FoldingSetNodeID &ID,
41904190
ID.AddPointer(Parent.getPointer());
41914191
}
41924192

4193+
/// The safety of a parent type does not have an impact on a nested type within
4194+
/// it. This produces the recursive properties of a given type that should
4195+
/// be propagated to a nested type, which won't include any "IsUnsafe" bit
4196+
/// determined based on the declaration itself.
4197+
static RecursiveTypeProperties getRecursivePropertiesAsParent(Type type) {
4198+
if (!type)
4199+
return RecursiveTypeProperties();
4200+
4201+
// We only need to do anything interesting at all for unsafe types.
4202+
auto properties = type->getRecursiveProperties();
4203+
if (!properties.isUnsafe())
4204+
return properties;
4205+
4206+
if (auto nominal = type->getAnyNominal()) {
4207+
// If the nominal wasn't itself unsafe, then we got the unsafety from
4208+
// something else (e.g., a generic argument), so it won't change.
4209+
if (nominal->getExplicitSafety() != ExplicitSafety::Unsafe)
4210+
return properties;
4211+
}
4212+
4213+
// Drop the "unsafe" bit. We have to recompute it without considering the
4214+
// enclosing nominal type.
4215+
properties = RecursiveTypeProperties(
4216+
properties.getBits() & ~static_cast<unsigned>(RecursiveTypeProperties::IsUnsafe));
4217+
4218+
// Check generic arguments of parent types.
4219+
while (type) {
4220+
// Merge from the generic arguments.
4221+
if (auto boundGeneric = type->getAs<BoundGenericType>()) {
4222+
for (auto genericArg : boundGeneric->getGenericArgs())
4223+
properties |= genericArg->getRecursiveProperties();
4224+
}
4225+
4226+
if (auto nominalOrBound = type->getAs<NominalOrBoundGenericNominalType>()) {
4227+
type = nominalOrBound->getParent();
4228+
continue;
4229+
}
4230+
4231+
if (auto unbound = type->getAs<UnboundGenericType>()) {
4232+
type = unbound->getParent();
4233+
continue;
4234+
}
4235+
4236+
break;
4237+
};
4238+
4239+
return properties;
4240+
}
4241+
41934242
UnboundGenericType *UnboundGenericType::
41944243
get(GenericTypeDecl *TheDecl, Type Parent, const ASTContext &C) {
41954244
llvm::FoldingSetNodeID ID;
@@ -4198,7 +4247,7 @@ get(GenericTypeDecl *TheDecl, Type Parent, const ASTContext &C) {
41984247
RecursiveTypeProperties properties;
41994248
if (TheDecl->getExplicitSafety() == ExplicitSafety::Unsafe)
42004249
properties |= RecursiveTypeProperties::IsUnsafe;
4201-
if (Parent) properties |= Parent->getRecursiveProperties();
4250+
properties |= getRecursivePropertiesAsParent(Parent);
42024251

42034252
auto arena = getArena(properties);
42044253

@@ -4252,7 +4301,7 @@ BoundGenericType *BoundGenericType::get(NominalTypeDecl *TheDecl,
42524301
RecursiveTypeProperties properties;
42534302
if (TheDecl->getExplicitSafety() == ExplicitSafety::Unsafe)
42544303
properties |= RecursiveTypeProperties::IsUnsafe;
4255-
if (Parent) properties |= Parent->getRecursiveProperties();
4304+
properties |= getRecursivePropertiesAsParent(Parent);
42564305
for (Type Arg : GenericArgs) {
42574306
properties |= Arg->getRecursiveProperties();
42584307
}
@@ -4335,7 +4384,7 @@ EnumType *EnumType::get(EnumDecl *D, Type Parent, const ASTContext &C) {
43354384
RecursiveTypeProperties properties;
43364385
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43374386
properties |= RecursiveTypeProperties::IsUnsafe;
4338-
if (Parent) properties |= Parent->getRecursiveProperties();
4387+
properties |= getRecursivePropertiesAsParent(Parent);
43394388
auto arena = getArena(properties);
43404389

43414390
auto *&known = C.getImpl().getArena(arena).EnumTypes[{D, Parent}];
@@ -4353,7 +4402,7 @@ StructType *StructType::get(StructDecl *D, Type Parent, const ASTContext &C) {
43534402
RecursiveTypeProperties properties;
43544403
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43554404
properties |= RecursiveTypeProperties::IsUnsafe;
4356-
if (Parent) properties |= Parent->getRecursiveProperties();
4405+
properties |= getRecursivePropertiesAsParent(Parent);
43574406
auto arena = getArena(properties);
43584407

43594408
auto *&known = C.getImpl().getArena(arena).StructTypes[{D, Parent}];
@@ -4371,7 +4420,7 @@ ClassType *ClassType::get(ClassDecl *D, Type Parent, const ASTContext &C) {
43714420
RecursiveTypeProperties properties;
43724421
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43734422
properties |= RecursiveTypeProperties::IsUnsafe;
4374-
if (Parent) properties |= Parent->getRecursiveProperties();
4423+
properties |= getRecursivePropertiesAsParent(Parent);
43754424
auto arena = getArena(properties);
43764425

43774426
auto *&known = C.getImpl().getArena(arena).ClassTypes[{D, Parent}];
@@ -5538,7 +5587,7 @@ ProtocolType *ProtocolType::get(ProtocolDecl *D, Type Parent,
55385587
RecursiveTypeProperties properties;
55395588
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
55405589
properties |= RecursiveTypeProperties::IsUnsafe;
5541-
if (Parent) properties |= Parent->getRecursiveProperties();
5590+
properties |= getRecursivePropertiesAsParent(Parent);
55425591
auto arena = getArena(properties);
55435592

55445593
auto *&known = C.getImpl().getArena(arena).ProtocolTypes[{D, Parent}];

lib/Sema/TypeCheckUnsafe.cpp

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,12 @@ bool swift::enumerateUnsafeUses(ArrayRef<ProtocolConformanceRef> conformances,
327327
bool swift::enumerateUnsafeUses(SubstitutionMap subs,
328328
SourceLoc loc,
329329
llvm::function_ref<bool(UnsafeUse)> fn) {
330-
// FIXME: Check replacement types?
330+
// Replacement types.
331+
for (auto replacementType : subs.getReplacementTypes()) {
332+
if (replacementType->isUnsafe() &&
333+
fn(UnsafeUse::forReferenceToUnsafe(nullptr, false, replacementType, loc)))
334+
return true;
335+
}
331336

332337
// Check conformances.
333338
if (enumerateUnsafeUses(subs.getConformances(), loc, fn))
@@ -376,24 +381,69 @@ void swift::diagnoseUnsafeType(ASTContext &ctx, SourceLoc loc, Type type,
376381
return;
377382

378383
// Look for a specific @unsafe nominal type along the way.
379-
auto findSpecificUnsafeType = [](Type type) {
384+
class Walker : public TypeWalker {
385+
public:
380386
Type specificType;
381-
(void)type.findIf([&specificType](Type type) {
387+
388+
Action walkToTypePre(Type type) override {
389+
if (specificType)
390+
return Action::Stop;
391+
392+
// If this refers to a nominal type that is @unsafe, store that.
382393
if (auto typeDecl = type->getAnyNominal()) {
383394
if (typeDecl->getExplicitSafety() == ExplicitSafety::Unsafe) {
384395
specificType = type;
385-
return false;
396+
return Action::Stop;
386397
}
387398
}
388399

389-
return false;
390-
});
391-
return specificType;
400+
// Do not recurse into nominal types, because we do not want to visit
401+
// their "parent" types.
402+
if (isa<NominalOrBoundGenericNominalType>(type.getPointer()) ||
403+
isa<UnboundGenericType>(type.getPointer())) {
404+
// Recurse into the generic arguments. This operation is recursive,
405+
// because we also need to see the generic arguments of parent types.
406+
walkGenericArguments(type);
407+
408+
return Action::SkipNode;
409+
}
410+
411+
return Action::Continue;
412+
}
413+
414+
private:
415+
/// Recursively walk the generic arguments of this type and its parent
416+
/// types.
417+
void walkGenericArguments(Type type) {
418+
if (!type)
419+
return;
420+
421+
// Walk the generic arguments.
422+
if (auto boundGeneric = type->getAs<BoundGenericType>()) {
423+
for (auto genericArg : boundGeneric->getGenericArgs())
424+
genericArg.walk(*this);
425+
}
426+
427+
if (auto nominalOrBound = type->getAs<NominalOrBoundGenericNominalType>())
428+
return walkGenericArguments(nominalOrBound->getParent());
429+
430+
if (auto unbound = type->getAs<UnboundGenericType>())
431+
return walkGenericArguments(unbound->getParent());
432+
}
392433
};
393434

394-
Type specificType = findSpecificUnsafeType(type);
395-
if (!specificType)
396-
specificType = findSpecificUnsafeType(type->getCanonicalType());
435+
// Look for a canonical unsafe type.
436+
Walker walker;
437+
type->getCanonicalType().walk(walker);
438+
Type specificType = walker.specificType;
439+
440+
// Look for an unsafe type in the non-canonical type, which is a better answer
441+
// if we can find it.
442+
walker.specificType = Type();
443+
type.walk(walker);
444+
if (specificType && walker.specificType &&
445+
specificType->isEqual(walker.specificType))
446+
specificType = walker.specificType;
397447

398448
diagnose(specificType ? specificType : type);
399449
}

test/Unsafe/safe.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,15 @@ struct UnsafeContainingUnspecified {
277277
typealias A = Int
278278

279279
func getA() -> A { 0 }
280+
281+
@safe
282+
struct Y {
283+
var value: Int
284+
}
285+
286+
func f() {
287+
_ = Y(value: 5)
288+
}
280289
}
281290

282291

0 commit comments

Comments
 (0)