1010import torch
1111from PIL import Image as PILImage
1212from torch import Tensor
13+ from torchcodec import AudioSamples
1314from torchcodec .decoders import AudioDecoder
1415from torchcodec .encoders import AudioEncoder
1516
@@ -251,10 +252,10 @@ def encode_video(
251252 }
252253
253254
254- def encode_audio (
255- audio : Any ,
255+ def encode_audio ( # noqa: C901 # noqa: PLR0913
256+ audio : AudioDecoder | bytes | str | Path | np . ndarray | Tensor | dict [ str , Any ] ,
256257 b64encode : bool = False ,
257- sample_rate : int = 16000 ,
258+ sample_rate : int | None = None ,
258259 file_name : str = "audio.wav" ,
259260 encode_sample_rate : int = 16000 ,
260261 max_duration : float | None = None ,
@@ -273,90 +274,154 @@ def encode_audio(
273274 ],
274275 str | int | float | None ,
275276]:
277+ """Decode audio (if necessary) and re-encode to the specified format."""
278+ samples = _decode_audio (audio , sample_rate = sample_rate , max_duration = max_duration )
279+
280+ bitrate_val = (
281+ int (bitrate .rstrip ("k" )) * 1000 if bitrate .endswith ("k" ) else int (bitrate )
282+ )
283+ format_val = audio_format .lower ()
284+
285+ encoded_audio = _encode_audio (
286+ samples = samples ,
287+ resample_rate = encode_sample_rate ,
288+ bitrate = bitrate_val ,
289+ audio_format = format_val ,
290+ mono = mono ,
291+ )
292+
293+ return {
294+ "type" : "audio_base64" if b64encode else "audio_file" ,
295+ "audio" : (
296+ base64 .b64encode (encoded_audio ).decode ("utf-8" )
297+ if b64encode
298+ else encoded_audio
299+ ),
300+ "file_name" : get_file_name (audio )
301+ if isinstance (audio , str | Path )
302+ else file_name ,
303+ "format" : audio_format ,
304+ "mimetype" : f"audio/{ format_val } " ,
305+ "audio_samples" : samples .sample_rate ,
306+ "audio_seconds" : samples .duration_seconds ,
307+ "audio_bytes" : len (encoded_audio ),
308+ }
309+
310+
def _decode_audio(  # noqa: C901, PLR0912
    audio: AudioDecoder | bytes | str | Path | np.ndarray | Tensor | dict[str, Any],
    sample_rate: int | None = None,
    max_duration: float | None = None,
) -> AudioSamples:
    """Decode audio from various input types into AudioSamples.

    Supported inputs:

    * ``dict`` -- unwrapped: the payload is read from ``"data"`` (or ``"url"``)
      and the rate from ``"sample_rate"`` / ``"sampling_rate"``, falling back
      to the ``sample_rate`` argument.
    * ``np.ndarray`` -- converted with ``torch.from_numpy`` and handled as a
      tensor.
    * ``AudioDecoder`` -- samples are read from it directly.
    * floating-point ``Tensor`` -- treated as already-decoded samples;
      ``sample_rate`` is required.
    * ``torch.uint8`` ``Tensor`` or ``bytes`` -- treated as encoded audio and
      passed to ``AudioDecoder`` together with ``sample_rate``.
    * ``str`` / ``Path`` -- a URL (fetched with ``httpx``) or a local file whose
      bytes are passed to ``AudioDecoder``.

    Args:
        audio: The audio input, in one of the forms listed above.
        sample_rate: Sample rate of decoded input, or the rate handed to
            ``AudioDecoder`` for encoded bytes / uint8 tensors.
        max_duration: Optional cap, in seconds, on the returned audio.

    Returns:
        The decoded samples.

    Raises:
        ValueError: If a dict has neither ``"data"`` nor ``"url"``, if decoded
            samples are given without a sample rate, if a tensor has an
            unsupported dtype, if a local file does not exist, or if the input
            type is not supported.
    """
    # If input is a dict, unwrap it into a function call
    if isinstance(audio, dict):
        sample_rate = audio.get("sample_rate", audio.get("sampling_rate", sample_rate))
        if "data" not in audio and "url" not in audio:
            raise ValueError(
                f"Audio dict must contain either 'data' or 'url' keys, got {audio}"
            )
        # Do not use `data or url`: that evaluates the truthiness of the
        # payload, which raises for ndarray/Tensor data holding more than one
        # element. Empty bytes/str still fall back to the URL, as before.
        payload = audio.get("data")
        if payload is None or (isinstance(payload, bytes | str) and not payload):
            payload = audio.get("url")
        return _decode_audio(
            audio=payload,
            sample_rate=sample_rate,
            max_duration=max_duration,
        )

    # Convert numpy array to torch tensor and re-call
    if isinstance(audio, np.ndarray):
        return _decode_audio(
            audio=torch.from_numpy(audio),
            sample_rate=sample_rate,
            max_duration=max_duration,
        )

    samples: AudioSamples

    # HF datasets return AudioDecoder for audio column
    if isinstance(audio, AudioDecoder):
        samples = audio.get_samples_played_in_range(stop_seconds=max_duration)

    elif isinstance(audio, Tensor):
        # If float stream assume decoded audio
        if torch.is_floating_point(audio):
            if sample_rate is None:
                raise ValueError("Sample rate must be set for decoded audio")

            # A 1-D waveform has no channel axis, so `shape[1]` below would
            # raise IndexError; treat it as a single channel.
            if audio.ndim == 1:
                audio = audio.unsqueeze(0)

            full_duration = audio.shape[1] / sample_rate
            # If max_duration is set, trim the audio to that duration
            if max_duration is not None:
                num_samples = int(max_duration * sample_rate)
                duration = min(max_duration, full_duration)
                trimmed = audio[:, :num_samples]
            else:
                duration = full_duration
                trimmed = audio

            samples = AudioSamples(
                data=trimmed,
                pts_seconds=0.0,
                duration_seconds=duration,
                sample_rate=sample_rate,
            )
        # If bytes tensor assume encoded audio
        elif audio.dtype == torch.uint8:
            decoder = AudioDecoder(
                source=audio,
                sample_rate=sample_rate,
            )
            samples = decoder.get_samples_played_in_range(stop_seconds=max_duration)

        else:
            raise ValueError(f"Unsupported audio type: {type(audio)}")

    # If bytes, assume encoded audio
    elif isinstance(audio, bytes):
        decoder = AudioDecoder(
            source=audio,
            sample_rate=sample_rate,
        )
        samples = decoder.get_samples_played_in_range(stop_seconds=max_duration)

    # If str or Path, assume file path or URL to encoded audio
    elif isinstance(audio, str | Path):
        if isinstance(audio, str) and is_url(audio):
            response = httpx.get(audio)
            response.raise_for_status()
            encoded = response.content
        else:
            if not Path(audio).exists():
                raise ValueError(f"Audio file does not exist: {audio}")
            encoded = Path(audio).read_bytes()
        # NOTE(review): unlike the bytes / uint8 branches, `sample_rate` is not
        # forwarded to the decoder here -- confirm this is intentional.
        decoder = AudioDecoder(
            source=encoded,
        )
        samples = decoder.get_samples_played_in_range(stop_seconds=max_duration)
    else:
        raise ValueError(f"Unsupported audio type: {type(audio)}")

    return samples
403+
404+
def _encode_audio(
    samples: AudioSamples,
    resample_rate: int | None = None,
    bitrate: int = 64000,
    audio_format: str = "mp3",
    mono: bool = True,
) -> bytes:
    """Encode decoded samples with ``AudioEncoder`` and return the raw bytes.

    Args:
        samples: Decoded audio; its ``data`` and ``sample_rate`` feed the
            encoder.
        resample_rate: Passed to the encoder as the output ``sample_rate``.
        bitrate: Bit rate in bits per second; only forwarded when
            ``audio_format`` is ``"mp3"``, otherwise ``None`` is passed.
        audio_format: Container/codec name handed to the encoder as ``format``.
        mono: When true the encoder is asked for one channel; otherwise
            ``None`` is passed for ``num_channels``.

    Returns:
        The encoded audio as a ``bytes`` object.
    """
    # Only mp3 receives an explicit bit rate; every other format gets None.
    mp3_bit_rate = bitrate if audio_format == "mp3" else None
    channel_count = 1 if mono else None

    encoded_tensor = AudioEncoder(
        samples=samples.data,
        sample_rate=samples.sample_rate,
    ).to_tensor(
        format=audio_format,
        bit_rate=mp3_bit_rate,
        num_channels=channel_count,
        sample_rate=resample_rate,
    )

    return encoded_tensor.numpy().tobytes()
360425
361426
362427def get_file_name (path : Path | str ) -> str :
0 commit comments