22from typing import List , Tuple
33
44import torch
5+ from huggingface_hub import snapshot_download
6+ from huggingface_hub .errors import LocalEntryNotFoundError
57from pyannote .audio import Model , Pipeline
68from pyannote .audio .core .task import Problem , Resolution , Specifications
79from pyannote .audio .pipelines import VoiceActivityDetection
1214_PIPELINE = None
1315
1416
15- def get_pipeline ( device : torch . device ) -> Pipeline :
17+ def resolve_local_segmentation_path ( model_id : str ) -> str :
1618 """
17- Retrieves a PyAnnote voice activity detection pipeline and move it to the specified device.
18- The pipeline is loaded only once and reused across subsequent calls.
19- It requires the Hugging Face API token to be set in the HF_TOKEN environment variable.
19+ Returns the local snapshot path of the segmentation model, downloading it with HF_TOKEN if it is not cached locally.
2020 """
21- global _PIPELINE
22- if _PIPELINE is not None :
23- return _PIPELINE .to (device )
24-
2521 try :
26- hf_token = os .environ ["HF_TOKEN" ]
27- except KeyError as exc :
28- raise ValueError ("HF_TOKEN environment variable is not set" ) from exc
22+ return snapshot_download (
23+ repo_id = model_id ,
24+ local_files_only = True ,
25+ )
26+ except LocalEntryNotFoundError :
27+ pass
28+
29+ hf_token = os .getenv ("HF_TOKEN" )
30+ if not hf_token :
31+ raise RuntimeError (
32+ f"Model { model_id } was not found locally, "
33+ f"and no HF_TOKEN was provided to download it."
34+ )
35+
36+ return snapshot_download (
37+ repo_id = model_id ,
38+ token = hf_token ,
39+ )
40+
41+
42+ def load_segmentation_model (model_id : str ) -> Model :
43+ """
44+ Loads the segmentation model from a local snapshot.
45+ If it doesn’t exist, it first creates (downloads) the snapshot.
46+ """
47+ local_path = resolve_local_segmentation_path (model_id = model_id )
2948
3049 with torch .serialization .safe_globals (
3150 [
@@ -35,7 +54,23 @@ def get_pipeline(device: torch.device) -> Pipeline:
3554 Resolution ,
3655 ]
3756 ):
38- model = Model .from_pretrained ("pyannote/segmentation-3.0" , token = hf_token )
57+ return Model .from_pretrained (local_path )
58+
59+
60+ def get_pipeline (
61+ device : torch .device , model_id : str = "pyannote/segmentation-3.0"
62+ ) -> Pipeline :
63+ """
64+ Retrieves a PyAnnote voice activity detection pipeline and moves it to the specified device.
65+ The pipeline is loaded only once and reused across subsequent calls.
66+ The Hugging Face API token (HF_TOKEN environment variable) is only required when the model is not already cached locally.
67+ """
68+ global _PIPELINE
69+ if _PIPELINE is not None :
70+ return _PIPELINE .to (device )
71+
72+ model = load_segmentation_model (model_id = model_id )
73+
3974 _PIPELINE = VoiceActivityDetection (segmentation = model )
4075 _PIPELINE .instantiate ({"min_duration_on" : 0.0 , "min_duration_off" : 0.0 })
4176
0 commit comments