Skip to content

Commit eee40b4

Browse files
committed
[CSGen] Make collection subscript result type inference more principled
Infer result type of a subscript with Array or Dictionary base type if argument type matches the key type exactly or it's a supported literal type. This helps to maintain the existing behavior without having to resort to "favored type" computation.
1 parent 979e046 commit eee40b4

File tree

2 files changed

+71
-45
lines changed

2 files changed

+71
-45
lines changed

lib/Sema/CSGen.cpp

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,61 @@ namespace {
10511051
return tv;
10521052
}
10531053

1054+
/// Attempt to infer a result type of a subscript reference where
1055+
/// the base type is either a stdlib Array or a Dictionary type.
1056+
/// This is a more principled version of the old performance hack
1057+
/// that used "favored" types deduced by the constraint optimizer
1058+
/// and is important to maintain pre-existing solver behavior.
1059+
Type inferCollectionSubscriptResultType(Type baseTy,
1060+
ArgumentList *argumentList) {
1061+
auto isLValueBase = false;
1062+
auto baseObjTy = baseTy;
1063+
if (baseObjTy->is<LValueType>()) {
1064+
isLValueBase = true;
1065+
baseObjTy = baseObjTy->getWithoutSpecifierType();
1066+
}
1067+
1068+
auto subscriptResultType = [&isLValueBase](Type valueTy,
1069+
bool isOptional) -> Type {
1070+
Type outputTy = isOptional ? OptionalType::get(valueTy) : valueTy;
1071+
return isLValueBase ? LValueType::get(outputTy) : outputTy;
1072+
};
1073+
1074+
if (auto *argument = argumentList->getUnlabeledUnaryExpr()) {
1075+
auto argumentTy = CS.getType(argument);
1076+
1077+
auto elementTy = baseObjTy->getArrayElementType();
1078+
1079+
if (!elementTy)
1080+
elementTy = baseObjTy->getInlineArrayElementType();
1081+
1082+
if (elementTy) {
1083+
if (auto arraySliceTy =
1084+
dyn_cast<ArraySliceType>(baseObjTy.getPointer())) {
1085+
baseObjTy = arraySliceTy->getDesugaredType();
1086+
}
1087+
1088+
if (argumentTy->isInt() || isExpr<IntegerLiteralExpr>(argument))
1089+
return subscriptResultType(elementTy, /*isOptional*/ false);
1090+
} else if (auto dictTy = CS.isDictionaryType(baseObjTy)) {
1091+
auto [keyTy, valueTy] = *dictTy;
1092+
1093+
if (keyTy->isString() &&
1094+
(isExpr<StringLiteralExpr>(argument) ||
1095+
isExpr<InterpolatedStringLiteralExpr>(argument)))
1096+
return subscriptResultType(valueTy, /*isOptional*/ true);
1097+
1098+
if (keyTy->isInt() && isExpr<IntegerLiteralExpr>(argument))
1099+
return subscriptResultType(valueTy, /*isOptional*/ true);
1100+
1101+
if (keyTy->isEqual(argumentTy))
1102+
return subscriptResultType(valueTy, /*isOptional*/ true);
1103+
}
1104+
}
1105+
1106+
return Type();
1107+
}
1108+
10541109
/// Add constraints for a subscript operation.
10551110
Type addSubscriptConstraints(
10561111
Expr *anchor, Type baseTy, ValueDecl *declOrNull, ArgumentList *argList,
@@ -1074,52 +1129,10 @@ namespace {
10741129

10751130
Type outputTy;
10761131

1077-
// For an integer subscript expression on an array slice type, instead of
1078-
// introducing a new type variable we can easily obtain the element type.
1079-
if (isa<SubscriptExpr>(anchor)) {
1080-
1081-
auto isLValueBase = false;
1082-
auto baseObjTy = baseTy;
1083-
if (baseObjTy->is<LValueType>()) {
1084-
isLValueBase = true;
1085-
baseObjTy = baseObjTy->getWithoutSpecifierType();
1086-
}
1087-
1088-
auto elementTy = baseObjTy->getArrayElementType();
1132+
// Attempt to infer the result type of a stdlib collection subscript.
1133+
if (isa<SubscriptExpr>(anchor))
1134+
outputTy = inferCollectionSubscriptResultType(baseTy, argList);
10891135

1090-
if (!elementTy)
1091-
elementTy = baseObjTy->getInlineArrayElementType();
1092-
1093-
if (elementTy) {
1094-
1095-
if (auto arraySliceTy =
1096-
dyn_cast<ArraySliceType>(baseObjTy.getPointer())) {
1097-
baseObjTy = arraySliceTy->getDesugaredType();
1098-
}
1099-
1100-
if (argList->isUnlabeledUnary() &&
1101-
isa<IntegerLiteralExpr>(argList->getExpr(0))) {
1102-
1103-
outputTy = elementTy;
1104-
1105-
if (isLValueBase)
1106-
outputTy = LValueType::get(outputTy);
1107-
}
1108-
} else if (auto dictTy = CS.isDictionaryType(baseObjTy)) {
1109-
auto keyTy = dictTy->first;
1110-
auto valueTy = dictTy->second;
1111-
1112-
if (argList->isUnlabeledUnary()) {
1113-
auto argTy = CS.getType(argList->getExpr(0));
1114-
if (isFavoredParamAndArg(CS, keyTy, argTy)) {
1115-
outputTy = OptionalType::get(valueTy);
1116-
if (isLValueBase)
1117-
outputTy = LValueType::get(outputTy);
1118-
}
1119-
}
1120-
}
1121-
}
1122-
11231136
if (outputTy.isNull()) {
11241137
outputTy = CS.createTypeVariable(resultLocator,
11251138
TVO_CanBindToLValue | TVO_CanBindToNoEscape);
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: %scale-test --begin 1 --end 15 --step 1 --select NumLeafScopes %s --expected-exit-code 0
2+
// REQUIRES: asserts,no_asan
3+
4+
func test(carrierDict: [String : Double]) {
5+
var exhaustTemperature: Double
6+
exhaustTemperature = (
7+
(carrierDict[""] ?? 0.0) +
8+
%for i in range(N):
9+
(carrierDict[""] ?? 0.0) +
10+
%end
11+
(carrierDict[""] ?? 0.0)
12+
) / 4
13+
}

0 commit comments

Comments
 (0)