Skip to content

Commit 017cfc3

Browse files
committed
Refactor: Remove einops dependency in Magi1 VAE
Replaces the `rearrange` function with equivalent native PyTorch `permute` and `reshape` operations. This change removes the external `einops` library dependency, simplifying the model's environment.
1 parent 87299a4 commit 017cfc3

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_magi1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import torch
1818
import torch.nn as nn
1919
import torch.nn.functional as F
20-
from einops import rearrange
2120
from timm.layers import trunc_normal_
2221

2322
from ...configuration_utils import ConfigMixin, register_to_config
@@ -293,7 +292,9 @@ def forward(self, x):
293292
self.patch_size[2],
294293
self.unpatch_channels,
295294
)
296-
x = rearrange(x, "B lT lH lW pT pH pW C -> B C (lT pT) (lH pH) (lW pW)", C=self.unpatch_channels)
295+
# Rearrange from (B, lT, lH, lW, pT, pH, pW, C) to (B, C, lT*pT, lH*pH, lW*pW)
296+
x = x.permute(0, 7, 1, 4, 2, 5, 3, 6) # (B, C, lT, pT, lH, pH, lW, pW)
297+
x = x.reshape(B, self.unpatch_channels, latentT * self.patch_size[0], latentH * self.patch_size[1], latentW * self.patch_size[2])
297298

298299
x = self.conv_out(x)
299300
return x

0 commit comments

Comments
 (0)