Commit 5ce4814

Authored by misrasaurabh1, codeflash-ai[bot], github-actions[bot], a-r-r-o-w, and aseembits93
⚡️ Speed up method AutoencoderKLWan.clear_cache by 886% (huggingface#11665)
* ⚡️ Speed up method `AutoencoderKLWan.clear_cache` by 886%

  **Key optimizations:**

  - Compute the number of `WanCausalConv3d` modules in each model (`encoder`/`decoder`) **only once during initialization** and store the result in `self._cached_conv_counts`. This removes the repeated module-tree traversals at every `clear_cache` call, which profiling showed to be the main bottleneck.
  - The internal helper `_count_conv3d_fast` is optimized via a generator expression with `sum` for efficiency.

  All comments from the original code are preserved, except for updated or removed local docstrings/comments relevant to changed lines. **Function signatures and outputs remain unchanged.**

* Apply style fixes

* Apply suggestions from code review

  Co-authored-by: Aryan <[email protected]>

* Apply style fixes

---------

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <[email protected]>
Co-authored-by: Aryan <[email protected]>
Co-authored-by: Aseem Saxena <[email protected]>
1 parent 1bc6f3d commit 5ce4814
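The core idea described in the commit message is to pay the module-tree walk once, at construction time, and turn every later `clear_cache` into a dictionary lookup. A minimal, self-contained sketch of that pattern follows; the `TinyAutoencoder` class and the use of plain `nn.Conv3d` are illustrative stand-ins, not the actual diffusers implementation:

```python
import torch.nn as nn


class TinyAutoencoder(nn.Module):
    """Illustrative stand-in for AutoencoderKLWan (not the real class)."""

    def __init__(self):
        super().__init__()
        # Stand-ins for the real encoder/decoder stacks.
        self.encoder = nn.Sequential(nn.Conv3d(3, 8, 3), nn.Conv3d(8, 8, 3))
        self.decoder = nn.Sequential(nn.Conv3d(8, 8, 3), nn.Conv3d(8, 3, 3))
        # Count the relevant submodules once, at construction time.
        self._cached_conv_counts = {
            "decoder": sum(isinstance(m, nn.Conv3d) for m in self.decoder.modules()),
            "encoder": sum(isinstance(m, nn.Conv3d) for m in self.encoder.modules()),
        }

    def clear_cache(self):
        # Reuse the precomputed counts instead of walking the module tree again.
        self._conv_num = self._cached_conv_counts["decoder"]
        self._feat_map = [None] * self._conv_num
        self._enc_conv_num = self._cached_conv_counts["encoder"]
        self._enc_feat_map = [None] * self._enc_conv_num
```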

File tree

1 file changed (+13, -9 lines)

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 13 additions & 9 deletions

@@ -749,6 +749,16 @@ def __init__(
         self.tile_sample_stride_height = 192
         self.tile_sample_stride_width = 192
 
+        # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
+        self._cached_conv_counts = {
+            "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
+            if self.decoder is not None
+            else 0,
+            "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
+            if self.encoder is not None
+            else 0,
+        }
+
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -801,18 +811,12 @@ def disable_slicing(self) -> None:
         self.use_slicing = False
 
     def clear_cache(self):
-        def _count_conv3d(model):
-            count = 0
-            for m in model.modules():
-                if isinstance(m, WanCausalConv3d):
-                    count += 1
-            return count
-
-        self._conv_num = _count_conv3d(self.decoder)
+        # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
+        self._conv_num = self._cached_conv_counts["decoder"]
         self._conv_idx = [0]
         self._feat_map = [None] * self._conv_num
         # cache encode
-        self._enc_conv_num = _count_conv3d(self.encoder)
+        self._enc_conv_num = self._cached_conv_counts["encoder"]
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num

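The 886% figure comes from profiling the real model. As a rough, hypothetical illustration of why the cached lookup wins, one can time a full module-tree recount against a cached value on a toy module stack (the stack depth and the timings below are arbitrary and will vary by machine):

```python
import timeit

import torch.nn as nn

# A deep-ish module tree to make the traversal cost visible.
model = nn.Sequential(*[nn.Sequential(nn.Conv3d(4, 4, 3), nn.SiLU()) for _ in range(64)])


def count_every_call():
    # Old behaviour: walk the whole module tree on each clear_cache().
    return sum(isinstance(m, nn.Conv3d) for m in model.modules())


cached = count_every_call()  # new behaviour: computed once up front


def use_cached():
    return cached


print("recount each call:", timeit.timeit(count_every_call, number=10_000))
print("cached lookup    :", timeit.timeit(use_cached, number=10_000))
```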