Skip to content

Commit 8e78880

Browse files
Merge pull request #54 from salute-developers/hf_local_cache
No need to repass HF_TOKEN, explicit check for a local copy
2 parents e49565e + c73b962 commit 8e78880

File tree

1 file changed

+47
-12
lines changed

1 file changed

+47
-12
lines changed

gigaam/vad_utils.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from typing import List, Tuple
33

44
import torch
5+
from huggingface_hub import snapshot_download
6+
from huggingface_hub.errors import LocalEntryNotFoundError
57
from pyannote.audio import Model, Pipeline
68
from pyannote.audio.core.task import Problem, Resolution, Specifications
79
from pyannote.audio.pipelines import VoiceActivityDetection
@@ -12,20 +14,37 @@
1214
_PIPELINE = None
1315

1416

15-
def get_pipeline(device: torch.device) -> Pipeline:
17+
def resolve_local_segmentation_path(model_id: str) -> str:
1618
"""
17-
Retrieves a PyAnnote voice activity detection pipeline and move it to the specified device.
18-
The pipeline is loaded only once and reused across subsequent calls.
19-
It requires the Hugging Face API token to be set in the HF_TOKEN environment variable.
19+
Returns the local snapshot path of the segmentation model, downloading it with HF_TOKEN if it is not cached.
2020
"""
21-
global _PIPELINE
22-
if _PIPELINE is not None:
23-
return _PIPELINE.to(device)
24-
2521
try:
26-
hf_token = os.environ["HF_TOKEN"]
27-
except KeyError as exc:
28-
raise ValueError("HF_TOKEN environment variable is not set") from exc
22+
return snapshot_download(
23+
repo_id=model_id,
24+
local_files_only=True,
25+
)
26+
except LocalEntryNotFoundError:
27+
pass
28+
29+
hf_token = os.getenv("HF_TOKEN")
30+
if not hf_token:
31+
raise RuntimeError(
32+
f"Model {model_id} was not found locally, "
33+
f"and no HF_TOKEN was provided to download it."
34+
)
35+
36+
return snapshot_download(
37+
repo_id=model_id,
38+
token=hf_token,
39+
)
40+
41+
42+
def load_segmentation_model(model_id: str) -> Model:
43+
"""
44+
Loads the segmentation model from a local snapshot.
45+
If the snapshot does not exist locally, it is downloaded first.
46+
"""
47+
local_path = resolve_local_segmentation_path(model_id=model_id)
2948

3049
with torch.serialization.safe_globals(
3150
[
@@ -35,7 +54,23 @@ def get_pipeline(device: torch.device) -> Pipeline:
3554
Resolution,
3655
]
3756
):
38-
model = Model.from_pretrained("pyannote/segmentation-3.0", token=hf_token)
57+
return Model.from_pretrained(local_path)
58+
59+
60+
def get_pipeline(
61+
device: torch.device, model_id: str = "pyannote/segmentation-3.0"
62+
) -> Pipeline:
63+
"""
64+
Retrieves a PyAnnote voice activity detection pipeline and moves it to the specified device.
65+
The pipeline is loaded only once and reused across subsequent calls.
66+
The HF_TOKEN environment variable is required only when the model is not already available in the local Hugging Face cache.
67+
"""
68+
global _PIPELINE
69+
if _PIPELINE is not None:
70+
return _PIPELINE.to(device)
71+
72+
model = load_segmentation_model(model_id=model_id)
73+
3974
_PIPELINE = VoiceActivityDetection(segmentation=model)
4075
_PIPELINE.instantiate({"min_duration_on": 0.0, "min_duration_off": 0.0})
4176

0 commit comments

Comments
 (0)