diff --git a/src/chatterbox/mtl_tts.py b/src/chatterbox/mtl_tts.py index 2c9cf0524..3a0ce887c 100644 --- a/src/chatterbox/mtl_tts.py +++ b/src/chatterbox/mtl_tts.py @@ -191,7 +191,13 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS': return cls(t3, s3gen, ve, tokenizer, device, conds=conds) @classmethod - def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS': + def from_pretrained(cls, device: torch.device, **kwargs) -> 'ChatterboxMultilingualTTS': + + # Check if model_snapshot_path is provided in kwargs + if 'model_snapshot_path' in kwargs: + model_snapshot_path = Path(kwargs['model_snapshot_path']) + if model_snapshot_path.exists(): + return cls.from_local(ckpt_dir=model_snapshot_path, device=device) ckpt_dir = Path( snapshot_download( repo_id=REPO_ID, diff --git a/src/chatterbox/tts.py b/src/chatterbox/tts.py index 4737f1823..d6df8a5fd 100644 --- a/src/chatterbox/tts.py +++ b/src/chatterbox/tts.py @@ -165,7 +165,7 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS': return cls(t3, s3gen, ve, tokenizer, device, conds=conds) @classmethod - def from_pretrained(cls, device) -> 'ChatterboxTTS': + def from_pretrained(cls, device, **kwargs) -> 'ChatterboxTTS': # Check if MPS is available on macOS if device == "mps" and not torch.backends.mps.is_available(): if not torch.backends.mps.is_built(): @@ -173,6 +173,11 @@ def from_pretrained(cls, device) -> 'ChatterboxTTS': else: print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") device = "cpu" + # Check if model_snapshot_path is provided in kwargs + if 'model_snapshot_path' in kwargs: + model_snapshot_path = Path(kwargs['model_snapshot_path']) + if model_snapshot_path.exists(): + return cls.from_local(ckpt_dir=model_snapshot_path, device=device) for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]: local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath) diff --git a/src/chatterbox/tts_turbo.py b/src/chatterbox/tts_turbo.py index cb0326005..099faa419 100644 --- a/src/chatterbox/tts_turbo.py +++ b/src/chatterbox/tts_turbo.py @@ -183,7 +183,7 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxTurboTTS': return cls(t3, s3gen, ve, tokenizer, device, conds=conds) @classmethod - def from_pretrained(cls, device) -> 'ChatterboxTurboTTS': + def from_pretrained(cls, device, **kwargs) -> 'ChatterboxTurboTTS': # Check if MPS is available on macOS if device == "mps" and not torch.backends.mps.is_available(): if not torch.backends.mps.is_built(): @@ -192,6 +192,12 @@ def from_pretrained(cls, device) -> 'ChatterboxTurboTTS': print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") device = "cpu" + # Check if model_snapshot_path is provided in kwargs + if 'model_snapshot_path' in kwargs: + model_snapshot_path = Path(kwargs['model_snapshot_path']) + if model_snapshot_path.exists(): + return cls.from_local(ckpt_dir=model_snapshot_path, device=device) + local_path = snapshot_download( repo_id=REPO_ID, token=os.getenv("HF_TOKEN") or True, diff --git a/src/chatterbox/vc.py b/src/chatterbox/vc.py index a9c32ed35..b6d335859 100644 --- a/src/chatterbox/vc.py +++ b/src/chatterbox/vc.py @@ -59,7 +59,7 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxVC': return cls(s3gen, device, ref_dict=ref_dict) @classmethod - def from_pretrained(cls, device) -> 'ChatterboxVC': + def from_pretrained(cls, device, **kwargs) -> 'ChatterboxVC': # Check if MPS is available on macOS if device == "mps" and not torch.backends.mps.is_available(): if not torch.backends.mps.is_built(): @@ -67,7 +67,13 @@ def from_pretrained(cls, device) -> 'ChatterboxVC': else: print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") device = "cpu" - + + # Check if model_snapshot_path is provided in kwargs + if 'model_snapshot_path' in kwargs: + model_snapshot_path = Path(kwargs['model_snapshot_path']) + if model_snapshot_path.exists(): + return cls.from_local(ckpt_dir=model_snapshot_path, device=device) + for fpath in ["s3gen.safetensors", "conds.pt"]: local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)