Skip to content

Commit c96bfa5

Browse files
authored
[Mochi-1] ensuring to compute the fourier features in FP32 in Mochi encoder (huggingface#10031)
compute fourier features in FP32.
1 parent 6b288ec commit c96bfa5

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,8 @@ def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
437437

438438
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
439439
r"""Forward method of the `FourierFeatures` class."""
440-
440+
original_dtype = inputs.dtype
441+
inputs = inputs.to(torch.float32)
441442
num_channels = inputs.shape[1]
442443
num_freqs = (self.stop - self.start) // self.step
443444

@@ -450,7 +451,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
450451
# Scale channels by frequency.
451452
h = w * h
452453

453-
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1)
454+
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)
454455

455456

456457
class MochiEncoder3D(nn.Module):

0 commit comments

Comments
 (0)