Skip to content

Commit 06b1aee

Browse files
slavapestovkavon
authored andcommitted
Evaluator: Cache circular evaluation to avoid redundant diagnostics
Previously, if a request R evaluated itself N times, we would emit N "circular reference" diagnostics. These add no value, so instead let's cache the user-provided default value on the first circular evaluation. This changes things slightly so that instead of returning an llvm::Expected<Request::OutputType>, various evaluator methods take a callback which can produce the default value. The existing evaluateOrDefault() interface is unchanged, and a new evaluateOrFatal() entry point replaces llvm::cantFail(ctx.evaluator(...)). Direct callers of the evaluator's operator() were updated to pass in the callback. The benefit of the callback over evaluateOrDefault() is that if the default value is expensive to constuct, like a dummy generic signature, we will only construct it in the case where a cycle actually happened, otherwise we just delete the callback. (cherry picked from commit b8fcf1c)
1 parent 25355b2 commit 06b1aee

File tree

19 files changed

+137
-220
lines changed

19 files changed

+137
-220
lines changed

include/swift/AST/Evaluator.h

Lines changed: 42 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -70,47 +70,6 @@ class PrettyStackTraceRequest : public llvm::PrettyStackTraceEntry {
7070
}
7171
};
7272

73-
/// An llvm::ErrorInfo container for a request in which a cycle was detected
74-
/// and diagnosed.
75-
template <typename Request>
76-
struct CyclicalRequestError :
77-
public llvm::ErrorInfo<CyclicalRequestError<Request>> {
78-
public:
79-
static char ID;
80-
const Request &request;
81-
const Evaluator &evaluator;
82-
83-
CyclicalRequestError(const Request &request, const Evaluator &evaluator)
84-
: request(request), evaluator(evaluator) {}
85-
86-
virtual void log(llvm::raw_ostream &out) const override;
87-
88-
virtual std::error_code convertToErrorCode() const override {
89-
// This is essentially unused, but is a temporary requirement for
90-
// llvm::ErrorInfo subclasses.
91-
llvm_unreachable("shouldn't get std::error_code from CyclicalRequestError");
92-
}
93-
};
94-
95-
template <typename Request>
96-
char CyclicalRequestError<Request>::ID = '\0';
97-
98-
/// Evaluates a given request or returns a default value if a cycle is detected.
99-
template <typename Request>
100-
typename Request::OutputType
101-
evaluateOrDefault(
102-
Evaluator &eval, Request req, typename Request::OutputType def) {
103-
auto result = eval(req);
104-
if (auto err = result.takeError()) {
105-
llvm::handleAllErrors(std::move(err),
106-
[](const CyclicalRequestError<Request> &E) {
107-
// cycle detected
108-
});
109-
return def;
110-
}
111-
return *result;
112-
}
113-
11473
/// Report that a request of the given kind is being evaluated, so it
11574
/// can be recorded by the stats reporter.
11675
template<typename Request>
@@ -256,37 +215,26 @@ class Evaluator {
256215

257216
/// Retrieve the result produced by evaluating a request that can
258217
/// be cached.
259-
template<typename Request,
218+
template<typename Request, typename Fn,
260219
typename std::enable_if<Request::isEverCached>::type * = nullptr>
261-
llvm::Expected<typename Request::OutputType>
262-
operator()(const Request &request) {
220+
typename Request::OutputType
221+
operator()(const Request &request, Fn defaultValueFn) {
263222
// The request can be cached, but check a predicate to determine
264223
// whether this particular instance is cached. This allows more
265224
// fine-grained control over which instances get cache.
266225
if (request.isCached())
267-
return getResultCached(request);
226+
return getResultCached(request, std::move(defaultValueFn));
268227

269-
return getResultUncached(request);
228+
return getResultUncached(request, std::move(defaultValueFn));
270229
}
271230

272231
/// Retrieve the result produced by evaluating a request that
273232
/// will never be cached.
274-
template<typename Request,
233+
template<typename Request, typename Fn,
275234
typename std::enable_if<!Request::isEverCached>::type * = nullptr>
276-
llvm::Expected<typename Request::OutputType>
277-
operator()(const Request &request) {
278-
return getResultUncached(request);
279-
}
280-
281-
/// Evaluate a set of requests and return their results as a tuple.
282-
///
283-
/// Use this to describe cases where there are multiple (known)
284-
/// requests that all need to be satisfied.
285-
template<typename ...Requests>
286-
std::tuple<llvm::Expected<typename Requests::OutputType>...>
287-
operator()(const Requests &...requests) {
288-
return std::tuple<llvm::Expected<typename Requests::OutputType>...>(
289-
(*this)(requests)...);
235+
typename Request::OutputType
236+
operator()(const Request &request, Fn defaultValueFn) {
237+
return getResultUncached(request, std::move(defaultValueFn));
290238
}
291239

292240
/// Cache a precomputed value for the given request, so that it will not
@@ -304,7 +252,9 @@ class Evaluator {
304252
typename std::enable_if<!Request::hasExternalCache>::type* = nullptr>
305253
void cacheOutput(const Request &request,
306254
typename Request::OutputType &&output) {
307-
cache.insert<Request>(request, std::move(output));
255+
bool inserted = cache.insert<Request>(request, std::move(output));
256+
assert(inserted && "Request result was already cached");
257+
(void) inserted;
308258
}
309259

310260
template<typename Request,
@@ -351,15 +301,14 @@ class Evaluator {
351301
void finishedRequest(const ActiveRequest &request);
352302

353303
/// Produce the result of the request without caching.
354-
template<typename Request>
355-
llvm::Expected<typename Request::OutputType>
356-
getResultUncached(const Request &request) {
304+
template<typename Request, typename Fn>
305+
typename Request::OutputType
306+
getResultUncached(const Request &request, Fn defaultValueFn) {
357307
auto activeReq = ActiveRequest(request);
358308

359309
// Check for a cycle.
360310
if (checkDependency(activeReq)) {
361-
return llvm::Error(
362-
std::make_unique<CyclicalRequestError<Request>>(request, *this));
311+
return defaultValueFn();
363312
}
364313

365314
PrettyStackTraceRequest<Request> prettyStackTrace(request);
@@ -370,7 +319,7 @@ class Evaluator {
370319

371320
recorder.beginRequest<Request>();
372321

373-
auto &&result = getRequestFunction<Request>()(request, *this);
322+
auto result = getRequestFunction<Request>()(request, *this);
374323

375324
recorder.endRequest<Request>(request);
376325

@@ -381,16 +330,16 @@ class Evaluator {
381330
// done.
382331
finishedRequest(activeReq);
383332

384-
return std::move(result);
333+
return result;
385334
}
386335

387336
/// Get the result of a request, consulting an external cache
388337
/// provided by the request to retrieve previously-computed results
389338
/// and detect recursion.
390-
template<typename Request,
339+
template<typename Request, typename Fn,
391340
typename std::enable_if<Request::hasExternalCache>::type * = nullptr>
392-
llvm::Expected<typename Request::OutputType>
393-
getResultCached(const Request &request) {
341+
typename Request::OutputType
342+
getResultCached(const Request &request, Fn defaultValueFn) {
394343
// If there is a cached result, return it.
395344
if (auto cached = request.getCachedResult()) {
396345
recorder.replayCachedRequest(request);
@@ -399,13 +348,10 @@ class Evaluator {
399348
}
400349

401350
// Compute the result.
402-
auto result = getResultUncached(request);
351+
auto result = getResultUncached(request, std::move(defaultValueFn));
403352

404353
// Cache the result if applicable.
405-
if (!result)
406-
return result;
407-
408-
request.cacheResult(*result);
354+
request.cacheResult(result);
409355

410356
// Return it.
411357
return result;
@@ -414,10 +360,10 @@ class Evaluator {
414360
/// Get the result of a request, consulting the general cache to
415361
/// retrieve previously-computed results and detect recursion.
416362
template<
417-
typename Request,
363+
typename Request, typename Fn,
418364
typename std::enable_if<!Request::hasExternalCache>::type * = nullptr>
419-
llvm::Expected<typename Request::OutputType>
420-
getResultCached(const Request &request) {
365+
typename Request::OutputType
366+
getResultCached(const Request &request, Fn defaultValueFn) {
421367
// If we already have an entry for this request in the cache, return it.
422368
auto known = cache.find_as<Request>(request);
423369
if (known != cache.end<Request>()) {
@@ -428,12 +374,10 @@ class Evaluator {
428374
}
429375

430376
// Compute the result.
431-
auto result = getResultUncached(request);
432-
if (!result)
433-
return result;
377+
auto result = getResultUncached(request, std::move(defaultValueFn));
434378

435379
// Cache the result.
436-
cache.insert<Request>(request, *result);
380+
cache.insert<Request>(request, result);
437381
return result;
438382
}
439383

@@ -465,11 +409,20 @@ class Evaluator {
465409
}
466410
};
467411

468-
template <typename Request>
469-
void CyclicalRequestError<Request>::log(llvm::raw_ostream &out) const {
470-
out << "Cycle detected:\n";
471-
simple_display(out, request);
472-
out << "\n";
412+
/// Evaluates a given request or returns a default value if a cycle is detected.
413+
template<typename Request>
414+
typename Request::OutputType
415+
evaluateOrDefault(Evaluator &eval, Request req, typename Request::OutputType def) {
416+
return eval(req, [def]() { return def; });
417+
}
418+
419+
/// Evaluates a given request or returns a default value if a cycle is detected.
420+
template<typename Request>
421+
typename Request::OutputType
422+
evaluateOrFatal(Evaluator &eval, Request req) {
423+
return eval(req, []() -> typename Request::OutputType {
424+
llvm::report_fatal_error("Request cycle");
425+
});
473426
}
474427

475428
} // end namespace evaluator

include/swift/AST/RequestCache.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,12 +259,11 @@ class RequestCache {
259259
}
260260

261261
template <typename Request>
262-
void insert(Request req, typename Request::OutputType val) {
262+
bool insert(Request req, typename Request::OutputType val) {
263263
auto *cache = getCache<Request>();
264264
auto result = cache->insert({RequestKey<Request>(std::move(req)),
265-
std::move(val)});
266-
assert(result.second && "Request result was already cached");
267-
(void) result;
265+
std::move(val)});
266+
return result.second;
268267
}
269268

270269
template <typename Request>

lib/AST/Decl.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,16 +1330,9 @@ static GenericSignature getPlaceholderGenericSignature(
13301330
GenericSignature GenericContext::getGenericSignature() const {
13311331
// Don't use evaluateOrDefault() here, because getting the 'default value'
13321332
// is slightly expensive here so we don't want to do it eagerly.
1333-
auto result = getASTContext().evaluator(
1334-
GenericSignatureRequest{const_cast<GenericContext *>(this)});
1335-
if (auto err = result.takeError()) {
1336-
llvm::handleAllErrors(std::move(err),
1337-
[](const CyclicalRequestError<GenericSignatureRequest> &E) {
1338-
// cycle detected
1339-
});
1340-
return getPlaceholderGenericSignature(this);
1341-
}
1342-
return *result;
1333+
return getASTContext().evaluator(
1334+
GenericSignatureRequest{const_cast<GenericContext *>(this)},
1335+
[this]() { return getPlaceholderGenericSignature(this); });
13431336
}
13441337

13451338
GenericEnvironment *GenericContext::getGenericEnvironment() const {
@@ -3866,12 +3859,8 @@ bool ValueDecl::isRecursiveValidation() const {
38663859

38673860
Type ValueDecl::getInterfaceType() const {
38683861
auto &ctx = getASTContext();
3869-
if (auto type =
3870-
evaluateOrDefault(ctx.evaluator,
3871-
InterfaceTypeRequest{const_cast<ValueDecl *>(this)},
3872-
Type()))
3873-
return type;
3874-
return ErrorType::get(ctx);
3862+
return ctx.evaluator(InterfaceTypeRequest{const_cast<ValueDecl *>(this)},
3863+
[&ctx]() { return ErrorType::get(ctx); });
38753864
}
38763865

38773866
void ValueDecl::setInterfaceType(Type type) {
@@ -6741,7 +6730,6 @@ bool ProtocolDecl::isComputingRequirementSignature() const {
67416730
}
67426731

67436732
void ProtocolDecl::setRequirementSignature(RequirementSignature requirementSig) {
6744-
assert(!RequirementSig && "requirement signature already set");
67456733
RequirementSig = requirementSig;
67466734
}
67476735

@@ -10919,8 +10907,14 @@ Type ClassDecl::getSuperclass() const {
1091910907

1092010908
ClassDecl *ClassDecl::getSuperclassDecl() const {
1092110909
ASTContext &ctx = getASTContext();
10922-
return evaluateOrDefault(ctx.evaluator,
10923-
SuperclassDeclRequest{const_cast<ClassDecl *>(this)}, nullptr);
10910+
auto result = evaluateOrDefault(ctx.evaluator,
10911+
SuperclassDeclRequest{const_cast<ClassDecl *>(this)},
10912+
const_cast<ClassDecl *>(this));
10913+
10914+
if (result == this)
10915+
return nullptr;
10916+
10917+
return result;
1092410918
}
1092510919

1092610920
void ClassDecl::setSuperclass(Type superclass) {

lib/AST/NameLookup.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3206,16 +3206,10 @@ SuperclassDeclRequest::evaluate(Evaluator &evaluator,
32063206
// inheritance hierarchy by evaluating its superclass. This forces the
32073207
// diagnostic at this point and then suppresses the superclass failure.
32083208
if (superclass) {
3209-
auto result = Ctx.evaluator(SuperclassDeclRequest{superclass});
3210-
bool hadCycle = false;
3211-
if (auto err = result.takeError()) {
3212-
llvm::handleAllErrors(std::move(err),
3213-
[&hadCycle](const CyclicalRequestError<SuperclassDeclRequest> &E) {
3214-
hadCycle = true;
3215-
});
3216-
3217-
if (hadCycle)
3218-
return nullptr;
3209+
if (evaluateOrDefault(Ctx.evaluator,
3210+
SuperclassDeclRequest{superclass},
3211+
superclass) == superclass) {
3212+
return nullptr;
32193213
}
32203214

32213215
return superclass;

lib/AST/TypeCheckRequests.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -438,23 +438,27 @@ bool WhereClauseOwner::visitRequirements(
438438
TypeResolutionStage stage,
439439
llvm::function_ref<bool(Requirement, RequirementRepr *)> callback)
440440
const && {
441-
auto &evaluator = dc->getASTContext().evaluator;
441+
auto &ctx = dc->getASTContext();
442+
auto &evaluator = ctx.evaluator;
442443
auto requirements = getRequirements();
443444
for (unsigned index : indices(requirements)) {
444445
// Resolve to a requirement.
445-
auto req = evaluator(RequirementRequest{*this, index, stage});
446-
if (req) {
447-
// Invoke the callback. If it returns true, we're done.
448-
if (callback(*req, &requirements[index]))
449-
return true;
450-
446+
bool hadCycle = false;
447+
auto req = evaluator(RequirementRequest{*this, index, stage},
448+
[&ctx, &hadCycle]() {
449+
// If we encounter a cycle, just make a fake
450+
// same-type requirement. We will skip it below.
451+
hadCycle = true;
452+
return Requirement(RequirementKind::SameType,
453+
ErrorType::get(ctx),
454+
ErrorType::get(ctx));
455+
});
456+
if (hadCycle)
451457
continue;
452-
}
453458

454-
llvm::handleAllErrors(
455-
req.takeError(), [](const CyclicalRequestError<RequirementRequest> &E) {
456-
// cycle detected
457-
});
459+
// Invoke the callback. If it returns true, we're done.
460+
if (callback(req, &requirements[index]))
461+
return true;
458462
}
459463

460464
return false;

lib/DriverTool/sil_llvm_gen_main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ int sil_llvm_gen_main(ArrayRef<const char *> argv, void *MainAddr) {
221221
auto &eval = CI.getASTContext().evaluator;
222222
auto desc = getDescriptor();
223223
desc.out = &outFile->getOS();
224-
auto generatedMod = llvm::cantFail(eval(OptimizedIRRequest{desc}));
224+
auto generatedMod = evaluateOrFatal(eval, OptimizedIRRequest{desc});
225225
if (!generatedMod)
226226
return 1;
227227

0 commit comments

Comments
 (0)