Skip to content

Commit 3044dd8

Browse files
Metal backend: Add MPSGraph caching (pytorch#15346)
This pull request introduces a new caching infrastructure for compiled MPSGraph objects in the Metal backend, significantly improving performance for repeated matrix multiplication and convolution operations by reusing previously compiled graphs and their associated tensors. The changes also include cache statistics tracking and refactoring of the relevant code paths to leverage the cache. Note that caching for the attention operation is not yet implemented. ### MPSGraph Caching Infrastructure * Added a new global cache (`graph_cache`) for compiled `MPSGraph` objects and their input/output tensors, keyed by operation type and parameters, enabling reuse and reducing graph compilation overhead. * Introduced `GraphCacheKey` and `CachedGraph` structures, along with a custom hash function, to uniquely identify and store cached graphs for matrix multiplication and convolution operations. ### Matrix Multiplication and Convolution Refactoring * Refactored the matrix multiplication (`aoti_torch_mps_mm_out`) and convolution (`aoti_torch_mps_convolution`) functions to check the cache for an existing compiled graph before creating a new one, and to store newly compiled graphs in the cache for future reuse. [[1]](diffhunk://#diff-8b30270bea48d11579e41d09a082cfad335aa3bbb302302c2320d4e0da6b4680R261-R303) [[2]](diffhunk://#diff-8b30270bea48d11579e41d09a082cfad335aa3bbb302302c2320d4e0da6b4680L505-R654) * Removed manual release of `MPSGraph` objects after execution, as cached graphs are now retained for reuse, preventing unnecessary recompilation and memory leaks. [[1]](diffhunk://#diff-8b30270bea48d11579e41d09a082cfad335aa3bbb302302c2320d4e0da6b4680L282-L285) [[2]](diffhunk://#diff-8b30270bea48d11579e41d09a082cfad335aa3bbb302302c2320d4e0da6b4680R895-L759) ### Cache Statistics and Monitoring * Added `CacheStats` structure to track cache hits and misses.
1 parent 526eb18 commit 3044dd8

File tree

1 file changed

+279
-137
lines changed

1 file changed

+279
-137
lines changed

0 commit comments

Comments
 (0)