
Commit 218e875

BUG: fix IndexTTS2 on transformers 4.57.1 (#4158)

1 parent 06133fd
File tree: 2 files changed, +122 -6 lines

xinference/thirdparty/indextts/gpt/transformers_generation_utils.py

Lines changed: 71 additions & 5 deletions
```diff
@@ -30,8 +30,12 @@
     DynamicCache,
     EncoderDecoderCache,
     OffloadedCache,
-    QuantizedCacheConfig,
+    QuantizedCache,
     StaticCache,
+    SlidingWindowCache,
+    SinkCache,
+    HybridCache,
+    HybridChunkedCache,
 )
 from transformers.configuration_utils import PretrainedConfig
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
```
```diff
@@ -55,13 +59,10 @@
     AssistedCandidateGeneratorDifferentTokenizers,
     CandidateGenerator,
     PromptLookupCandidateGenerator,
-    _crop_past_key_values,
     _prepare_attention_mask,
     _prepare_token_type_ids,
 )
 from transformers.generation.configuration_utils import (
-    NEED_SETUP_CACHE_CLASSES_MAPPING,
-    QUANT_BACKEND_CLASSES_MAPPING,
     GenerationConfig,
     GenerationMode,
 )
```
```diff
@@ -111,6 +112,70 @@
 
 logger = logging.get_logger(__name__)
 
+# Compatibility with transformers 4.57.1+
+# These mappings are needed for the removed constants
+NEED_SETUP_CACHE_CLASSES_MAPPING = {
+    "auto": Cache,
+    "dynamic": DynamicCache,
+    "static": StaticCache,
+    "offloaded": OffloadedCache,
+    "sliding_window": SlidingWindowCache,
+    "sink": SinkCache,
+    "hybrid": HybridCache,
+    "hybrid_chunked": HybridChunkedCache,
+}
+
+# Mapping for quantized cache backends
+QUANT_BACKEND_CLASSES_MAPPING = {
+    "quanto": QuantizedCache,
+    "hqq": QuantizedCache,
+}
+
+# Compatibility class for removed QuantizedCacheConfig
+class QuantizedCacheConfig:
+    def __init__(self, backend: str = "quanto", nbits: int = 4,
+                 axis_key: int = 0, axis_value: int = 0,
+                 q_group_size: int = 64, residual_length: int = 128):
+        self.backend = backend
+        self.nbits = nbits
+        self.axis_key = axis_key
+        self.axis_value = axis_value
+        self.q_group_size = q_group_size
+        self.residual_length = residual_length
+
+# Compatibility function for removed _crop_past_key_values
+def _crop_past_key_values(model, past_key_values, max_length):
+    """
+    Crop past key values to a maximum length.
+    This is a compatibility function for the removed _crop_past_key_values.
+    """
+    if past_key_values is None:
+        return past_key_values
+
+    # If past_key_values is a Cache object
+    if hasattr(past_key_values, 'crop'):
+        return past_key_values.crop(max_length)
+
+    # If it's a tuple of tensors (legacy format)
+    if isinstance(past_key_values, tuple):
+        cropped_past_key_values = []
+        for layer_past_key_values in past_key_values:
+            if isinstance(layer_past_key_values, tuple) and len(layer_past_key_values) == 2:
+                # Standard format: (key, value)
+                key, value = layer_past_key_values
+                if key.shape[-2] > max_length:
+                    key = key[..., :max_length, :]
+                if value.shape[-2] > max_length:
+                    value = value[..., :max_length, :]
+                cropped_past_key_values.append((key, value))
+            else:
+                # Other formats, just append as is
+                cropped_past_key_values.append(layer_past_key_values)
+        return tuple(cropped_past_key_values)
+
+    # For other cache types, return as is
+    return past_key_values
+
 if is_accelerate_available():
     from accelerate.hooks import AlignDevicesHook, add_hook_to_module
 
```
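The restored constants are plain lookup tables, and the crop helper only has to trim the sequence dimension of cached tensors. Below is a minimal sketch of both ideas, assuming PyTorch and transformers are installed; the `CACHE_CLASSES` dict and `crop_legacy` helper are illustrative stand-ins written for this example (they are not part of the patch), and the tensor shapes are arbitrary.

```python
import torch
from transformers import DynamicCache

# Pick a cache implementation by name, as the restored mapping allows.
# Only the "dynamic" entry is reproduced here for brevity.
CACHE_CLASSES = {"dynamic": DynamicCache}
past = CACHE_CLASSES["dynamic"]()  # a fresh, empty cache

# Legacy cache format: one (key, value) pair per layer,
# each shaped (batch, num_heads, seq_len, head_dim).
legacy = tuple(
    (torch.randn(1, 2, 16, 8), torch.randn(1, 2, 16, 8)) for _ in range(2)
)

def crop_legacy(past_key_values, max_length):
    """Trim every cached key/value tensor to at most max_length positions,
    mirroring the tuple branch of the shim above."""
    cropped = []
    for key, value in past_key_values:
        cropped.append((key[..., :max_length, :], value[..., :max_length, :]))
    return tuple(cropped)

short = crop_legacy(legacy, 4)
print(short[0][0].shape)  # torch.Size([1, 2, 4, 8])
```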

```diff
@@ -1002,7 +1067,8 @@ def _get_logits_processor(
                     device=device,
                 )
             )
-        if generation_config.forced_decoder_ids is not None:
+        # Compatibility with transformers 4.57.1+: forced_decoder_ids has been removed
+        if hasattr(generation_config, 'forced_decoder_ids') and generation_config.forced_decoder_ids is not None:
             # TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT
             raise ValueError(
                 "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
```

xinference/thirdparty/indextts/gpt/transformers_gpt2.py

Lines changed: 51 additions & 1 deletion
```diff
@@ -32,7 +32,57 @@
 
 from indextts.gpt.transformers_generation_utils import GenerationMixin
 from indextts.gpt.transformers_modeling_utils import PreTrainedModel
-from transformers.modeling_utils import SequenceSummary
+# SequenceSummary has been removed in transformers 4.57.1+
+# Adding compatibility implementation
+class SequenceSummary(nn.Module):
+    """
+    Compute a single vector summary of a sequence hidden states.
+    """
+    def __init__(self, config):
+        super().__init__()
+        self.summary_type = getattr(config, 'summary_type', 'last')
+        self.summary_use_proj = getattr(config, 'summary_use_proj', True)
+        self.summary_activation = getattr(config, 'summary_activation', None)
+        self.summary_proj_to_labels = getattr(config, 'summary_proj_to_labels', True)
+        self.summary_first_dropout = getattr(config, 'summary_first_dropout', 0.1)
+
+        if self.summary_use_proj:
+            if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
+            self.activation = nn.Tanh()
+        else:
+            self.activation = lambda x: x
+
+        if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
+            self.dropout = nn.Dropout(config.summary_first_dropout)
+        else:
+            self.dropout = lambda x: x
+
+    def forward(self, hidden_states, cls_token_index=None):
+        if self.summary_type == 'last':
+            output = hidden_states[:, -1]
+        elif self.summary_type == 'first':
+            output = hidden_states[:, 0]
+        elif self.summary_type == 'mean':
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == 'cls_index':
+            if cls_token_index is None:
+                raise ValueError("cls_token_index must be specified when summary_type='cls_index'")
+            batch_size = hidden_states.size(0)
+            output = hidden_states[batch_size, cls_token_index]
+        else:
+            output = hidden_states[:, -1]  # fallback to last
+
+        output = self.dropout(output)
+        if self.summary_use_proj:
+            output = self.summary(output)
+        output = self.activation(output)
+        return output
 
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.modeling_outputs import (
```
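A small usage sketch of the compatibility class, assuming PyTorch is installed and the shim is importable as `indextts.gpt.transformers_gpt2` (the module path shown in this diff); the config values are illustrative, and `summary_type="mean"` is chosen for simplicity.

```python
from types import SimpleNamespace

import torch
from indextts.gpt.transformers_gpt2 import SequenceSummary  # import path assumed from this diff

config = SimpleNamespace(
    summary_type="mean",           # average over the sequence dimension
    summary_use_proj=True,
    summary_proj_to_labels=False,  # project back to hidden_size, not num_labels
    summary_activation="tanh",
    summary_first_dropout=0.1,
    hidden_size=768,
    num_labels=0,
)

head = SequenceSummary(config)
hidden_states = torch.randn(2, 10, 768)  # (batch, seq_len, hidden_size)
summary = head(hidden_states)            # mean -> dropout -> linear -> tanh
print(summary.shape)                     # torch.Size([2, 768])
```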
