Skip to content

Commit e1b978f

Browse files
committed
add tests
1 parent 7527624 commit e1b978f

File tree

4 files changed

+91
-2
lines changed

4 files changed

+91
-2
lines changed

modules/jax/testcontainers/whisper-diarization/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import tempfile
33
from typing import Optional
44

5-
from testcontainers.core.container import DockerContainer
6-
from testcontainers.core.waiting_utils import wait_container_is_ready
5+
from core.testcontainers.core.container import DockerContainer
6+
from core.testcontainers.core.waiting_utils import wait_container_is_ready
77
from urllib.error import URLError
88

99
class JAXWhisperDiarizationContainer(DockerContainer):

modules/jax/tests/test_jax.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pytest
2+
from testcontainers.jax import JAXContainer
3+
4+
def test_jax_container():
    """Start a JAX container and check that a trivial JAX command produces output."""
    snippet = "import jax; print(jax.numpy.add(1, 1))"

    with JAXContainer() as jax_container:
        jax_container.connect()

        # The decoded command output must contain the sum of 1 and 1.
        stdout = jax_container.run_jax_command(snippet).output.decode()
        assert "2" in stdout
11+
12+
def test_jax_container_gpu_support():
    """Check that the device list printed by JAX mentions a GPU."""
    with JAXContainer() as jax_container:
        jax_container.connect()

        listing = jax_container.run_jax_command("import jax; print(jax.devices())")
        # Lower-case the output so the match on "gpu" is case-insensitive.
        device_report = listing.output.decode().lower()
        assert "gpu" in device_report
21+
22+
def test_jax_container_jupyter():
    """Check the scheme and port of the Jupyter URL reported by the container."""
    with JAXContainer() as jax_container:
        jax_container.connect()

        url = jax_container.get_jupyter_url()
        uses_plain_http = url.startswith("http://")
        assert uses_plain_http
        mentions_port = ":8888" in url
        assert mentions_port

modules/jax/tests/test_whisper.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
from testcontainers.whisper_jax import WhisperJAXContainer
3+
4+
@pytest.mark.parametrize("model_name", ["openai/whisper-tiny", "openai/whisper-base"])
def test_whisper_jax_container(model_name):
    """Transcribe a file and a YouTube URL and check the shape of each result."""

    def check_transcript(transcript):
        # A transcript is expected to be a dict carrying a string under 'text'.
        assert isinstance(transcript, dict)
        assert 'text' in transcript
        assert isinstance(transcript['text'], str)

    with WhisperJAXContainer(model_name) as whisper:
        whisper.connect()

        # File transcription is validated before the YouTube request is made.
        check_transcript(whisper.transcribe_file("/path/to/test/audio.wav"))

        check_transcript(
            whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
        )
20+
21+
def test_whisper_jax_container_with_timestamps():
    """Check that a timestamped transcription carries 'text' and timestamped 'chunks'."""
    with WhisperJAXContainer() as whisper:
        whisper.connect()

        transcript = whisper.transcribe_file(
            "/path/to/test/audio.wav",
            return_timestamps=True,
        )
        assert isinstance(transcript, dict)
        assert 'text' in transcript
        assert 'chunks' in transcript

        chunks = transcript['chunks']
        assert isinstance(chunks, list)
        # Every chunk must expose a 'timestamp' entry.
        assert all('timestamp' in piece for piece in chunks)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
from testcontainers.jax_whisper_diarization import JAXWhisperDiarizationContainer
3+
4+
@pytest.fixture(scope="module")
def hf_token():
    """Return the Hugging Face token passed to the diarization container tests.

    The token is read from the ``HF_TOKEN`` environment variable so that no
    credential (or non-working placeholder) lives in the test source. When the
    variable is unset or blank, the tests that request this fixture are
    skipped instead of running with an invalid token.

    Returns:
        The non-empty token string, stripped of surrounding whitespace.
    """
    import os  # local import keeps this fixture self-contained

    token = os.environ.get("HF_TOKEN", "").strip()
    if not token:
        pytest.skip("HF_TOKEN environment variable is not set")
    return token
7+
8+
def test_jax_whisper_diarization_container(hf_token):
    """Diarize a file and a YouTube URL and check the shape of each result."""
    required_keys = ('speaker', 'text', 'timestamp')

    def check_segments(segments):
        # A result is expected to be a list of dicts, each with all required keys.
        assert isinstance(segments, list)
        assert all(isinstance(segment, dict) for segment in segments)
        assert all(
            all(key in segment for key in required_keys) for segment in segments
        )

    with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization:
        whisper_diarization.connect()

        # The file result is validated before the YouTube request is made.
        check_segments(
            whisper_diarization.transcribe_and_diarize_file("/path/to/test/audio.wav")
        )

        check_segments(
            whisper_diarization.transcribe_and_diarize_youtube(
                "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
            )
        )
23+
24+
def test_jax_whisper_diarization_container_without_grouping(hf_token):
    """Diarize a file with ``group_by_speaker=False`` and check the result shape."""
    required_keys = ('speaker', 'text', 'timestamp')

    with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization:
        whisper_diarization.connect()

        segments = whisper_diarization.transcribe_and_diarize_file(
            "/path/to/test/audio.wav",
            group_by_speaker=False,
        )
        assert isinstance(segments, list)
        assert all(isinstance(segment, dict) for segment in segments)
        # Each segment must carry every required key.
        assert all(
            all(key in segment for key in required_keys) for segment in segments
        )

0 commit comments

Comments
 (0)