Skip to content

Commit 73a82b9

Browse files
committed
fix auto device detection in voxtral
1 parent 97eb45e commit 73a82b9

File tree

2 files changed

+47
-62
lines changed

2 files changed

+47
-62
lines changed

internal/transcription/adapters/py/voxtral/voxtral_transcribe.py

Lines changed: 14 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,9 @@ def transcribe_audio(
3535
Dictionary containing transcription results
3636
"""
3737
# Determine device
38-
if device == "auto":
39-
device = "cuda" if torch.cuda.is_available() else "cpu"
38+
# if device == "auto":
39+
# device = "cuda" if torch.cuda.is_available() else "cpu"
40+
device = "cuda" if torch.cuda.is_available() else "cpu"
4041

4142
print(f"Loading Voxtral model on {device}...", file=sys.stderr)
4243

@@ -57,9 +58,7 @@ def transcribe_audio(
5758

5859
# Prepare transcription request using the proper method
5960
inputs = processor.apply_transcription_request(
60-
language=language,
61-
audio=audio_path,
62-
model_id=model_id
61+
language=language, audio=audio_path, model_id=model_id
6362
)
6463

6564
# Move inputs to device with correct dtype
@@ -76,8 +75,7 @@ def transcribe_audio(
7675

7776
# Decode only the newly generated tokens (skip the input prompt)
7877
decoded_outputs = processor.batch_decode(
79-
outputs[:, inputs.input_ids.shape[1]:],
80-
skip_special_tokens=True
78+
outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True
8179
)
8280

8381
transcription_text = decoded_outputs[0]
@@ -94,7 +92,7 @@ def transcribe_audio(
9492
"start": 0.0,
9593
"end": 0.0, # Duration unknown without audio analysis
9694
"text": transcription_text,
97-
"words": [] # Voxtral doesn't provide word-level timestamps
95+
"words": [], # Voxtral doesn't provide word-level timestamps
9896
}
9997
],
10098
"language": language,
@@ -104,7 +102,7 @@ def transcribe_audio(
104102

105103
# Write output
106104
output_file = Path(output_path)
107-
with output_file.open('w', encoding='utf-8') as f:
105+
with output_file.open("w", encoding="utf-8") as f:
108106
json.dump(result, f, ensure_ascii=False, indent=2)
109107

110108
print(f"Results written to {output_path}", file=sys.stderr)
@@ -116,40 +114,29 @@ def main():
116114
parser = argparse.ArgumentParser(
117115
description="Transcribe audio using Voxtral-mini model"
118116
)
117+
parser.add_argument("audio_path", type=str, help="Path to input audio file")
118+
parser.add_argument("output_path", type=str, help="Path to output JSON file")
119119
parser.add_argument(
120-
"audio_path",
121-
type=str,
122-
help="Path to input audio file"
123-
)
124-
parser.add_argument(
125-
"output_path",
126-
type=str,
127-
help="Path to output JSON file"
128-
)
129-
parser.add_argument(
130-
"--language",
131-
type=str,
132-
default="en",
133-
help="Language code (default: en)"
120+
"--language", type=str, default="en", help="Language code (default: en)"
134121
)
135122
parser.add_argument(
136123
"--model-id",
137124
type=str,
138125
default="mistralai/Voxtral-mini",
139-
help="HuggingFace model ID (default: mistralai/Voxtral-mini)"
126+
help="HuggingFace model ID (default: mistralai/Voxtral-mini)",
140127
)
141128
parser.add_argument(
142129
"--device",
143130
type=str,
144131
default="auto",
145132
choices=["cpu", "cuda", "auto"],
146-
help="Device to use (default: auto)"
133+
help="Device to use (default: auto)",
147134
)
148135
parser.add_argument(
149136
"--max-new-tokens",
150137
type=int,
151138
default=8192,
152-
help="Maximum number of tokens to generate (default: 8192)"
139+
help="Maximum number of tokens to generate (default: 8192)",
153140
)
154141

155142
args = parser.parse_args()
@@ -166,6 +153,7 @@ def main():
166153
except Exception as e:
167154
print(f"Error: {e}", file=sys.stderr)
168155
import traceback
156+
169157
traceback.print_exc(file=sys.stderr)
170158
sys.exit(1)
171159

internal/transcription/adapters/py/voxtral/voxtral_transcribe_buffered.py

Lines changed: 33 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -29,11 +29,13 @@ def split_audio_file(audio_path, chunk_duration_secs=1500):
2929
end_sample = min(start_sample + chunk_samples, len(audio))
3030
chunk_audio = audio[start_sample:end_sample]
3131
start_time = start_sample / sr
32-
chunks.append({
33-
'audio': chunk_audio,
34-
'start_time': start_time,
35-
'duration': len(chunk_audio) / sr
36-
})
32+
chunks.append(
33+
{
34+
"audio": chunk_audio,
35+
"start_time": start_time,
36+
"duration": len(chunk_audio) / sr,
37+
}
38+
)
3739

3840
return chunks, sr
3941

@@ -51,8 +53,9 @@ def transcribe_buffered(
5153
Transcribe long audio by splitting into chunks and merging results.
5254
"""
5355
# Determine device
54-
if device == "auto":
55-
device = "cuda" if torch.cuda.is_available() else "cpu"
56+
# if device == "auto":
57+
# device = "cuda" if torch.cuda.is_available() else "cpu"
58+
device = "cuda" if torch.cuda.is_available() else "cpu"
5659

5760
print(f"Loading Voxtral model on {device}...", file=sys.stderr)
5861

@@ -77,18 +80,19 @@ def transcribe_buffered(
7780
full_text = []
7881

7982
for i, chunk_info in enumerate(chunks):
80-
print(f"Transcribing chunk {i+1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...", file=sys.stderr)
83+
print(
84+
f"Transcribing chunk {i + 1}/{len(chunks)} (duration: {chunk_info['duration']:.1f}s)...",
85+
file=sys.stderr,
86+
)
8187

8288
# Save chunk to temporary file
8389
chunk_path = f"/tmp/voxtral_chunk_{i}.wav"
84-
sf.write(chunk_path, chunk_info['audio'], sr)
90+
sf.write(chunk_path, chunk_info["audio"], sr)
8591

8692
try:
8793
# Prepare transcription request for this chunk
8894
inputs = processor.apply_transcription_request(
89-
language=language,
90-
audio=chunk_path,
91-
model_id=model_id
95+
language=language, audio=chunk_path, model_id=model_id
9296
)
9397

9498
# Move inputs to device with correct dtype
@@ -103,14 +107,15 @@ def transcribe_buffered(
103107

104108
# Decode only the newly generated tokens (skip the input prompt)
105109
decoded_outputs = processor.batch_decode(
106-
outputs[:, inputs.input_ids.shape[1]:],
107-
skip_special_tokens=True
110+
outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True
108111
)
109112

110113
chunk_text = decoded_outputs[0]
111114
full_text.append(chunk_text)
112115

113-
print(f"Chunk {i+1} complete: {len(chunk_text)} characters", file=sys.stderr)
116+
print(
117+
f"Chunk {i + 1} complete: {len(chunk_text)} characters", file=sys.stderr
118+
)
114119

115120
finally:
116121
# Clean up temp file
@@ -119,7 +124,9 @@ def transcribe_buffered(
119124

120125
# Concatenate all chunks
121126
final_text = " ".join(full_text)
122-
print(f"Transcription complete: {len(final_text)} characters total", file=sys.stderr)
127+
print(
128+
f"Transcription complete: {len(final_text)} characters total", file=sys.stderr
129+
)
123130

124131
# Prepare output in Scriberr format
125132
# Note: Voxtral doesn't provide word-level timestamps
@@ -131,7 +138,7 @@ def transcribe_buffered(
131138
"start": 0.0,
132139
"end": 0.0, # Duration unknown without audio analysis
133140
"text": final_text,
134-
"words": [] # Voxtral doesn't provide word-level timestamps
141+
"words": [], # Voxtral doesn't provide word-level timestamps
135142
}
136143
],
137144
"language": language,
@@ -144,7 +151,7 @@ def transcribe_buffered(
144151

145152
# Write output
146153
output_file_path = Path(output_file)
147-
with output_file_path.open('w', encoding='utf-8') as f:
154+
with output_file_path.open("w", encoding="utf-8") as f:
148155
json.dump(result, f, ensure_ascii=False, indent=2)
149156

150157
print(f"Results written to {output_file}", file=sys.stderr)
@@ -156,46 +163,35 @@ def main():
156163
parser = argparse.ArgumentParser(
157164
description="Transcribe long audio using Voxtral with chunking"
158165
)
166+
parser.add_argument("audio_path", type=str, help="Path to input audio file")
167+
parser.add_argument("output_path", type=str, help="Path to output JSON file")
159168
parser.add_argument(
160-
"audio_path",
161-
type=str,
162-
help="Path to input audio file"
163-
)
164-
parser.add_argument(
165-
"output_path",
166-
type=str,
167-
help="Path to output JSON file"
168-
)
169-
parser.add_argument(
170-
"--language",
171-
type=str,
172-
default="en",
173-
help="Language code (default: en)"
169+
"--language", type=str, default="en", help="Language code (default: en)"
174170
)
175171
parser.add_argument(
176172
"--model-id",
177173
type=str,
178174
default="mistralai/Voxtral-mini",
179-
help="HuggingFace model ID (default: mistralai/Voxtral-mini)"
175+
help="HuggingFace model ID (default: mistralai/Voxtral-mini)",
180176
)
181177
parser.add_argument(
182178
"--device",
183179
type=str,
184180
default="auto",
185181
choices=["cpu", "cuda", "auto"],
186-
help="Device to use (default: auto)"
182+
help="Device to use (default: auto)",
187183
)
188184
parser.add_argument(
189185
"--max-new-tokens",
190186
type=int,
191187
default=8192,
192-
help="Maximum number of tokens to generate per chunk (default: 8192)"
188+
help="Maximum number of tokens to generate per chunk (default: 8192)",
193189
)
194190
parser.add_argument(
195191
"--chunk-len",
196192
type=float,
197193
default=1500,
198-
help="Chunk duration in seconds (default: 1500 = 25 minutes)"
194+
help="Chunk duration in seconds (default: 1500 = 25 minutes)",
199195
)
200196

201197
args = parser.parse_args()
@@ -217,6 +213,7 @@ def main():
217213
except Exception as e:
218214
print(f"Error: {e}", file=sys.stderr)
219215
import traceback
216+
220217
traceback.print_exc(file=sys.stderr)
221218
sys.exit(1)
222219

0 commit comments

Comments (0)