@@ -20,7 +20,7 @@
 from transformers import (
     ClapFeatureExtractor,
     ClapModel,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     RobertaTokenizerFast,
     SpeechT5HifiGan,
@@ -196,7 +196,7 @@ def __init__( |
         text_encoder: ClapModel,
         text_encoder_2: Union[T5EncoderModel, VitsModel],
         projection_model: AudioLDM2ProjectionModel,
-        language_model: GPT2Model,
+        language_model: GPT2LMHeadModel,
         tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
         tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
         feature_extractor: ClapFeatureExtractor,
@@ -259,7 +259,10 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t |
         )

         device_type = torch_device.type
-        device = torch.device(f"{device_type}:{gpu_id or torch_device.index}")
+        device_str = device_type
+        if gpu_id or torch_device.index:
+            device_str = f"{device_str}:{gpu_id or torch_device.index}"
+        device = torch.device(device_str)

         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
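The device change above guards against an invalid device string: with the old f-string, a call where neither gpu_id nor torch_device.index is set would produce something like "cuda:None", which torch.device rejects. A minimal standalone sketch of the new logic (the helper name build_offload_device is made up for illustration and is not part of the pipeline):

import torch

def build_offload_device(torch_device: torch.device, gpu_id=None) -> torch.device:
    # Mirror the patched logic: start from the bare device type ("cuda", "cpu", ...)
    device_str = torch_device.type
    # Only append an explicit index when the caller passed gpu_id or the device
    # already carries one; otherwise torch.device("cuda:None") would raise.
    if gpu_id or torch_device.index:
        device_str = f"{device_str}:{gpu_id or torch_device.index}"
    return torch.device(device_str)

print(build_offload_device(torch.device("cuda")))     # cuda
print(build_offload_device(torch.device("cuda:1")))   # cuda:1
print(build_offload_device(torch.device("cuda"), 0))  # cuda (0 is falsy, so no index is appended)

Note that the truthiness check also skips an explicit index of 0, so "cuda:0" and bare "cuda" end up treated the same here.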
@@ -316,9 +319,9 @@ def generate_language_model( |
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)

             # forward pass to get next hidden states
-            output = self.language_model(**model_inputs, return_dict=True)
+            output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)

-            next_hidden_states = output.last_hidden_state
+            next_hidden_states = output.hidden_states[-1]

             # Update the model input
             inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
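The last hunk pairs with the GPT2Model -> GPT2LMHeadModel swap: the LM-head model's output carries logits rather than a last_hidden_state, so the generation loop now requests output_hidden_states=True and reads the final transformer layer from hidden_states[-1], which should match what GPT2Model previously exposed as last_hidden_state. A rough sanity-check sketch, using the stock "gpt2" checkpoint purely for illustration (AudioLDM2 ships its own language-model weights):

import torch
from transformers import GPT2LMHeadModel, GPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
inputs = tokenizer("audio of a dog barking", return_tensors="pt")

# GPT2Model exposes the final transformer states directly ...
base = GPT2Model.from_pretrained("gpt2")
with torch.no_grad():
    last_hidden = base(**inputs).last_hidden_state

# ... while GPT2LMHeadModel returns logits, so the same states have to be
# requested explicitly and read from the end of the hidden_states tuple.
lm = GPT2LMHeadModel.from_pretrained("gpt2")
with torch.no_grad():
    out = lm(**inputs, output_hidden_states=True, return_dict=True)
hidden_from_lm_head = out.hidden_states[-1]

print(torch.allclose(last_hidden, hidden_from_lm_head, atol=1e-5))  # expected: True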
|