Skip to content

Commit c1340b4

Browse files
committed
Code pathway cleanup
Signed-off-by: Samuel Monson <[email protected]>
1 parent aee230c commit c1340b4

File tree

1 file changed

+116
-51
lines changed

1 file changed

+116
-51
lines changed

src/guidellm/data/utils/functions.py

Lines changed: 116 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from PIL import Image as PILImage
1212
from torch import Tensor
13+
from torchcodec import AudioSamples
1314
from torchcodec.decoders import AudioDecoder
1415
from torchcodec.encoders import AudioEncoder
1516

@@ -251,10 +252,10 @@ def encode_video(
251252
}
252253

253254

254-
def encode_audio(
255-
audio: Any,
255+
def encode_audio( # noqa: C901 # noqa: PLR0913
256+
audio: AudioDecoder | bytes | str | Path | np.ndarray | Tensor | dict[str, Any],
256257
b64encode: bool = False,
257-
sample_rate: int = 16000,
258+
sample_rate: int | None = None,
258259
file_name: str = "audio.wav",
259260
encode_sample_rate: int = 16000,
260261
max_duration: float | None = None,
@@ -273,90 +274,154 @@ def encode_audio(
273274
],
274275
str | int | float | None,
275276
]:
277+
"""Decode audio (if nessary) and re-encode to specified format."""
278+
samples = _decode_audio(audio, sample_rate=sample_rate, max_duration=max_duration)
279+
280+
bitrate_val = (
281+
int(bitrate.rstrip("k")) * 1000 if bitrate.endswith("k") else int(bitrate)
282+
)
283+
format_val = audio_format.lower()
284+
285+
encoded_audio = _encode_audio(
286+
samples=samples,
287+
resample_rate=encode_sample_rate,
288+
bitrate=bitrate_val,
289+
audio_format=format_val,
290+
mono=mono,
291+
)
292+
293+
return {
294+
"type": "audio_base64" if b64encode else "audio_file",
295+
"audio": (
296+
base64.b64encode(encoded_audio).decode("utf-8")
297+
if b64encode
298+
else encoded_audio
299+
),
300+
"file_name": get_file_name(audio)
301+
if isinstance(audio, str | Path)
302+
else file_name,
303+
"format": audio_format,
304+
"mimetype": f"audio/{format_val}",
305+
"audio_samples": samples.sample_rate,
306+
"audio_seconds": samples.duration_seconds,
307+
"audio_bytes": len(encoded_audio),
308+
}
309+
310+
311+
def _decode_audio( # noqa: C901, PLR0912
312+
audio: AudioDecoder | bytes | str | Path | np.ndarray | Tensor | dict[str, Any],
313+
sample_rate: int | None = None,
314+
max_duration: float | None = None,
315+
) -> AudioSamples:
316+
"""Decode audio from various input types into AudioSamples."""
317+
# If input is a dict, unwrap it into a function call
276318
if isinstance(audio, dict):
277319
sample_rate = audio.get("sample_rate", audio.get("sampling_rate", sample_rate))
278320
if "data" not in audio and "url" not in audio:
279321
raise ValueError(
280322
f"Audio dict must contain either 'data' or 'url' keys, got {audio}"
281323
)
282-
return encode_audio(
324+
return _decode_audio(
283325
audio=audio.get("data") or audio.get("url"),
284326
sample_rate=sample_rate,
285-
encode_sample_rate=encode_sample_rate,
286327
max_duration=max_duration,
287-
mono=mono,
288-
audio_format=audio_format,
289-
bitrate=bitrate,
290328
)
291329

292-
decoder: AudioDecoder
330+
# Convert numpy array to torch tensor and re-call
331+
if isinstance(audio, np.ndarray):
332+
return _decode_audio(
333+
audio=torch.from_numpy(audio),
334+
sample_rate=sample_rate,
335+
max_duration=max_duration,
336+
)
337+
338+
samples: AudioSamples
293339

340+
# HF datasets return AudioDecoder for audio column
294341
if isinstance(audio, AudioDecoder):
295-
decoder = audio
296-
elif isinstance(audio, Tensor | bytes):
342+
samples = audio.get_samples_played_in_range(stop_seconds=max_duration)
343+
344+
elif isinstance(audio, Tensor):
345+
# If float stream assume decoded audio
346+
if torch.is_floating_point(audio):
347+
if sample_rate is None:
348+
raise ValueError("Sample rate must be set for decoded audio")
349+
350+
full_duration = audio.shape[1] / sample_rate
351+
# If max_duration is set, trim the audio to that duration
352+
if max_duration is not None:
353+
num_samples = int(max_duration * sample_rate)
354+
duration = min(max_duration, full_duration)
355+
data = audio[:, :num_samples]
356+
else:
357+
duration = full_duration
358+
data = audio
359+
360+
samples = AudioSamples(
361+
data=data,
362+
pts_seconds=0.0,
363+
duration_seconds=duration,
364+
sample_rate=sample_rate,
365+
)
366+
# If bytes tensor assume encoded audio
367+
elif audio.dtype == torch.uint8:
368+
decoder = AudioDecoder(
369+
source=audio,
370+
sample_rate=sample_rate,
371+
)
372+
samples = decoder.get_samples_played_in_range(stop_seconds=max_duration)
373+
374+
else:
375+
raise ValueError(f"Unsupported audio type: {type(audio)}")
376+
377+
# If bytes, assume encoded audio
378+
elif isinstance(audio, bytes):
297379
decoder = AudioDecoder(
298380
source=audio,
299381
sample_rate=sample_rate,
300382
)
383+
samples = decoder.get_samples_played_in_range(stop_seconds=max_duration)
384+
385+
# If str or Path, assume file path or URL to encoded audio
301386
elif isinstance(audio, str | Path):
302-
if is_url(audio):
387+
if isinstance(audio, str) and is_url(audio):
303388
response = httpx.get(audio)
304389
response.raise_for_status()
305-
file_name = get_file_name(audio)
306-
decoder = AudioDecoder(
307-
source=response.content,
308-
)
390+
data = response.content
309391
else:
310392
if not Path(audio).exists():
311393
raise ValueError(f"Audio file does not exist: {audio}")
312-
file_name = get_file_name(audio)
313-
decoder = AudioDecoder(
314-
source=audio,
315-
)
316-
elif isinstance(audio, np.ndarray):
317-
# AudioDecoder really needs a from_raw method
318-
pre_encoder = AudioEncoder(
319-
samples=torch.from_numpy(audio),
320-
sample_rate=sample_rate,
394+
data = Path(audio).read_bytes()
395+
decoder = AudioDecoder(
396+
source=data,
321397
)
322-
decoder = AudioDecoder(source=pre_encoder.to_tensor(format="wav"))
398+
samples = decoder.get_samples_played_in_range(stop_seconds=max_duration)
323399
else:
324400
raise ValueError(f"Unsupported audio type: {type(audio)}")
325401

326-
samples = decoder.get_samples_played_in_range(stop_seconds=max_duration)
402+
return samples
403+
404+
405+
def _encode_audio(
406+
samples: AudioSamples,
407+
resample_rate: int | None = None,
408+
bitrate: int = 64000,
409+
audio_format: str = "mp3",
410+
mono: bool = True,
411+
) -> bytes:
327412
encoder = AudioEncoder(
328413
samples=samples.data,
329414
sample_rate=samples.sample_rate,
330415
)
331416

332-
bit_rate_val = (
333-
int(bitrate.rstrip("k")) * 1000 if bitrate.endswith("k") else int(bitrate)
334-
)
335-
format_val = audio_format.lower()
336-
337417
audio_tensor = encoder.to_tensor(
338-
format=format_val,
339-
bit_rate=bit_rate_val if format_val == "mp3" else None,
418+
format=audio_format,
419+
bit_rate=bitrate if audio_format == "mp3" else None,
340420
num_channels=1 if mono else None,
341-
sample_rate=encode_sample_rate if sample_rate != encode_sample_rate else None,
421+
sample_rate=resample_rate,
342422
)
343423

344-
encoded_audio = audio_tensor.numpy().tobytes()
345-
346-
return {
347-
"type": "audio_base64" if b64encode else "audio_file",
348-
"audio": (
349-
base64.b64encode(encoded_audio).decode("utf-8")
350-
if b64encode
351-
else encoded_audio
352-
),
353-
"file_name": file_name,
354-
"format": audio_format,
355-
"mimetype": f"audio/{format_val}",
356-
"audio_samples": samples.sample_rate,
357-
"audio_seconds": samples.duration_seconds,
358-
"audio_bytes": len(encoded_audio),
359-
}
424+
return audio_tensor.numpy().tobytes()
360425

361426

362427
def get_file_name(path: Path | str) -> str:

0 commit comments

Comments
 (0)