@@ -1091,8 +1091,6 @@ def forward(
10911091 sample_posterior : bool = False ,
10921092 return_dict : bool = True ,
10931093 generator : Optional [torch .Generator ] = None ,
1094- encoder_local_batch_size : int = 2 ,
1095- decoder_local_batch_size : int = 2 ,
10961094 ) -> Union [DecoderOutput , torch .Tensor ]:
10971095 r"""
10981096 Args:
@@ -1103,18 +1101,14 @@ def forward(
11031101 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
11041102 generator (`torch.Generator`, *optional*):
11051103 PyTorch random number generator.
1106- encoder_local_batch_size (`int`, *optional*, defaults to 2):
1107- Local batch size for the encoder's batch inference.
1108- decoder_local_batch_size (`int`, *optional*, defaults to 2):
1109- Local batch size for the decoder's batch inference.
11101104 """
11111105 x = sample
1112- posterior = self .encode (x , local_batch_size = encoder_local_batch_size ).latent_dist
1106+ posterior = self .encode (x ).latent_dist
11131107 if sample_posterior :
11141108 z = posterior .sample (generator = generator )
11151109 else :
11161110 z = posterior .mode ()
1117- dec = self .decode (z , local_batch_size = decoder_local_batch_size ).sample
1111+ dec = self .decode (z ).sample
11181112
11191113 if not return_dict :
11201114 return (dec ,)
0 commit comments