Skip to content

Commit 9020086

Browse files
syntaxticsugrhlky
andauthored
[BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError in function prepare_latents caused by audio_vae_length (huggingface#10306)
[BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got float" torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor. in function prepare_latents: audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) ... audio = initial_audio_waveforms.new_zeros(audio_shape) audio_vae_length evaluates to float because self.transformer.config.sample_size returns a float Co-authored-by: hlky <[email protected]>
1 parent c8ee4af commit 9020086

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def prepare_latents(
446446
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
447447
)
448448

449-
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
449+
audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length
450450
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
451451

452452
# check num_channels

0 commit comments

Comments
 (0)