diff --git a/requirements.txt b/requirements.txt
index 14bca6b..6ee9f25 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
 celluloid>=0.2.0
-deepspeed>=0.12.4
 librosa>=0.10.1
 matplotlib>=3.8.1
 numpy>=1.26.2
diff --git a/resemble_enhance/denoiser/inference.py b/resemble_enhance/denoiser/inference.py
index 78a33b2..18869c2 100644
--- a/resemble_enhance/denoiser/inference.py
+++ b/resemble_enhance/denoiser/inference.py
@@ -4,7 +4,8 @@
 import torch
 
 from ..inference import inference
-from .train import Denoiser, HParams
+from .safetensors_loader import load_denoiser_model, create_default_denoiser
+from .hparams import HParams
 
 logger = logging.getLogger(__name__)
 
@@ -12,15 +13,8 @@
 @cache
 def load_denoiser(run_dir, device):
     if run_dir is None:
-        return Denoiser(HParams())
-    hp = HParams.load(run_dir)
-    denoiser = Denoiser(hp)
-    path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
-    state_dict = torch.load(path, map_location="cpu")["module"]
-    denoiser.load_state_dict(state_dict)
-    denoiser.eval()
-    denoiser.to(device)
-    return denoiser
+        return create_default_denoiser(device)
+    return load_denoiser_model(run_dir, device)
 
 
 @torch.inference_mode()
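With this change, `load_denoiser` becomes a thin wrapper over the format-dispatching loader added below. A minimal usage sketch; the run directory path here is an example, not a path shipped with the package:

```python
import torch

from resemble_enhance.denoiser.inference import load_denoiser

device = "cuda" if torch.cuda.is_available() else "cpu"

# Works for both layouts: a directory holding model.safetensors + config.json,
# or a legacy DeepSpeed run directory with ds/G/default/mp_rank_00_model_states.pt.
denoiser = load_denoiser("runs/denoiser", device)

# Passing None returns an untrained model with default hyperparameters.
default_denoiser = load_denoiser(None, device)
```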
diff --git a/resemble_enhance/denoiser/safetensors_loader.py b/resemble_enhance/denoiser/safetensors_loader.py
new file mode 100644
index 0000000..3d4a62b
--- /dev/null
+++ b/resemble_enhance/denoiser/safetensors_loader.py
@@ -0,0 +1,247 @@
+"""
+Enhanced model loader for the denoiser with safetensors support and JSON configs.
+Provides efficient loading without state_dict filtering when using safetensors.
+"""
+import json
+import logging
+from pathlib import Path
+from typing import Dict, Any, Union
+
+import torch
+from safetensors.torch import load_file
+
+from .denoiser import Denoiser
+from .hparams import HParams
+
+logger = logging.getLogger(__name__)
+
+
+class JSONConfig:
+    """Simple config class that works with JSON files instead of OmegaConf."""
+
+    def __init__(self, config_dict: Dict[str, Any]):
+        self._config = config_dict
+        # Set attributes for easy access
+        for key, value in config_dict.items():
+            if isinstance(value, dict):
+                setattr(self, key, JSONConfig(value))
+            else:
+                setattr(self, key, value)
+
+    @classmethod
+    def load(cls, json_path: Union[str, Path]) -> 'JSONConfig':
+        """Load config from a JSON file."""
+        with open(json_path, 'r') as f:
+            config_dict = json.load(f)
+        return cls(config_dict)
+
+    def get(self, key: str, default=None):
+        """Get a config value with a default."""
+        return getattr(self, key, default)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert back to a dictionary."""
+        return self._config
+
+
+def load_denoiser_from_safetensors(model_dir: Union[str, Path], device: str = "cpu") -> Denoiser:
+    """Load a denoiser model from safetensors format.
+
+    Args:
+        model_dir: Directory containing model.safetensors and config.json
+        device: Device to load the model on
+
+    Returns:
+        Loaded Denoiser model
+    """
+    model_path = Path(model_dir)
+
+    # Load config
+    config_path = model_path / "config.json"
+    if not config_path.exists():
+        raise FileNotFoundError(f"Config file not found: {config_path}")
+
+    config = JSONConfig.load(config_path)
+
+    # Create HParams with default values, then update from config
+    hp = HParams()
+
+    # For frozen dataclasses, we need to use object.__setattr__
+    config_dict = config.to_dict()
+    for key, value in config_dict.items():
+        if hasattr(hp, key):
+            try:
+                object.__setattr__(hp, key, value)
+            except Exception:
+                logger.warning(f"Could not set {key}={value} on HParams")
+
+    # Create model
+    model = Denoiser(hp)
+
+    # Load weights from safetensors
+    weights_path = model_path / "model.safetensors"
+    if not weights_path.exists():
+        raise FileNotFoundError(f"Model weights not found: {weights_path}")
+
+    state_dict = load_file(weights_path, device=device)
+
+    # No filtering needed - the safetensors file already contains only denoiser weights
+    model.load_state_dict(state_dict)
+    model.eval()
+    model.to(device)
+
+    logger.info(f"Loaded denoiser model from safetensors: {model_path}")
+    return model
+
+
+def load_denoiser_model(run_dir: Union[str, Path, None], device: str = "cpu") -> Denoiser:
+    """Load a denoiser model from either a safetensors directory or a DeepSpeed checkpoint.
+
+    Args:
+        run_dir: Path to a model directory (safetensors) or checkpoint directory (DeepSpeed)
+        device: Device to load the model on
+
+    Returns:
+        Loaded Denoiser model
+    """
+    if run_dir is None:
+        return create_default_denoiser(device)
+
+    run_dir = Path(run_dir)
+
+    # Check if this is a safetensors model directory
+    if (run_dir / "model.safetensors").exists() and (run_dir / "config.json").exists():
+        logger.info("Loading denoiser from safetensors format")
+        return load_denoiser_from_safetensors(run_dir, device)
+
+    # Fall back to DeepSpeed checkpoint loading
+    logger.info("Loading denoiser from DeepSpeed checkpoint format")
+    return load_denoiser_from_deepspeed(run_dir, device)
+
+
+def load_denoiser_from_deepspeed(run_dir: Path, device: str = "cpu") -> Denoiser:
+    """Load a denoiser model from a DeepSpeed checkpoint (legacy format).
+
+    Args:
+        run_dir: Path to the model checkpoint directory
+        device: Device to load the model on
+
+    Returns:
+        Loaded Denoiser model ready for inference
+    """
+    # Load hparams
+    hparams_path = run_dir / "hparams.yaml"
+    if not hparams_path.exists():
+        logger.warning(f"hparams.yaml not found in {run_dir}, using defaults")
+        hp = HParams()
+    else:
+        hp = HParams.load(run_dir)
+
+    # Create model
+    model = Denoiser(hp)
+
+    # Load the state dict from the DeepSpeed checkpoint
+    ckpt_path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
+    if not ckpt_path.exists():
+        logger.warning(f"Model checkpoint not found at {ckpt_path}, returning default model")
+        return create_default_denoiser(device)
+
+    state_dict = torch.load(ckpt_path, map_location="cpu")["module"]
+    model.load_state_dict(state_dict)
+    model.eval()
+    model.to(device)
+
+    logger.info(f"Loaded denoiser model from DeepSpeed checkpoint: {run_dir}")
+    return model
+
+
+def load_denoiser_from_enhancer_checkpoint(run_dir: Union[str, Path, None], device: str = "cpu") -> Denoiser:
+    """Load a denoiser model from an enhancer checkpoint.
+
+    This extracts the denoiser weights from an enhancer checkpoint that contains
+    both enhancer and denoiser weights.
+
+    Args:
+        run_dir: Path to the enhancer checkpoint directory (None for the default model)
+        device: Device to load the model on
+
+    Returns:
+        Loaded Denoiser model ready for inference
+    """
+    # If no run_dir is provided, create the default model
+    if run_dir is None:
+        return create_default_denoiser(device)
+
+    run_dir = Path(run_dir)
+
+    # Check if this is a safetensors enhancer directory with a separate denoiser next to it
+    denoiser_safetensors_dir = run_dir.parent / "denoiser"
+    if (denoiser_safetensors_dir / "model.safetensors").exists() and (denoiser_safetensors_dir / "config.json").exists():
+        logger.info("Loading denoiser from separate safetensors directory")
+        return load_denoiser_from_safetensors(denoiser_safetensors_dir, device)
+
+    # Try to load denoiser hparams first, fall back to enhancer hparams
+    denoiser_hp_path = run_dir / "denoiser_hparams.yaml"
+    if denoiser_hp_path.exists():
+        hp = HParams.load(denoiser_hp_path)
+    else:
+        # Load enhancer hparams and use the denoiser settings from it
+        from ..enhancer.hparams import HParams as EnhancerHParams
+
+        enhancer_hp_path = run_dir / "hparams.yaml"
+        if enhancer_hp_path.exists():
+            enhancer_hp = EnhancerHParams.load(run_dir)
+
+            # Create denoiser hparams from the enhancer config
+            hp = HParams()
+            # Copy the relevant settings if they exist
+            if hasattr(enhancer_hp, 'denoiser_run_dir') and enhancer_hp.denoiser_run_dir:
+                denoiser_run_dir = Path(enhancer_hp.denoiser_run_dir)
+                if (denoiser_run_dir / "hparams.yaml").exists():
+                    hp = HParams.load(denoiser_run_dir)
+        else:
+            # No hparams found, use defaults
+            hp = HParams()
+
+    model = Denoiser(hp)
+
+    # Load the state dict from the enhancer checkpoint
+    ckpt_path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
+    if not ckpt_path.exists():
+        # No checkpoint found, return the default model
+        return create_default_denoiser(device)
+
+    state_dict = torch.load(ckpt_path, map_location="cpu")["module"]
+
+    # Extract only the denoiser weights
+    denoiser_state_dict = {k.replace('denoiser.', '', 1): v for k, v in state_dict.items() if k.startswith('denoiser.')}
+
+    if not denoiser_state_dict:
+        # No denoiser weights found, return the default model
+        logger.warning("No denoiser weights found in enhancer checkpoint, using default model")
+        return create_default_denoiser(device)
+
+    model.load_state_dict(denoiser_state_dict)
+    model.eval()
+    model.to(device)
+
+    logger.info(f"Loaded denoiser from enhancer checkpoint: {run_dir}")
+    return model
+
+
+def create_default_denoiser(device: str = "cpu") -> Denoiser:
+    """Create a default denoiser model with default hyperparameters.
+
+    Args:
+        device: Device to create the model on
+
+    Returns:
+        Default Denoiser model (not trained)
+    """
+    hp = HParams()
+    model = Denoiser(hp)
+    model.eval()
+    model.to(device)
+    logger.info("Created a default denoiser model")
+    return model
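The safetensors layout this loader expects can be produced from a legacy checkpoint with a short one-off script. A hedged sketch, not the exporter actually used for the published weights: it assumes the standalone denoiser checkpoint stores its weights under the `"module"` key (as the DeepSpeed loader above does), and it serializes HParams with `dataclasses.asdict`, stringifying any non-JSON-native field values. The input and output paths are examples:

```python
import json
from dataclasses import asdict
from pathlib import Path

import torch
from safetensors.torch import save_file

from resemble_enhance.denoiser.hparams import HParams

run_dir = Path("runs/denoiser")        # example: legacy DeepSpeed run directory
out_dir = Path("model-repo/denoiser")  # example: target safetensors layout
out_dir.mkdir(parents=True, exist_ok=True)

hp = HParams.load(run_dir)
ckpt = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
state_dict = torch.load(ckpt, map_location="cpu")["module"]

# safetensors refuses shared or non-contiguous tensors; clone() gives each
# tensor its own storage and contiguous() normalizes the layout.
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
save_file(state_dict, out_dir / "model.safetensors")

with open(out_dir / "config.json", "w") as f:
    json.dump(asdict(hp), f, indent=2, default=str)
```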
diff --git a/resemble_enhance/enhancer/download.py b/resemble_enhance/enhancer/download.py
index 614b9a4..02a9973 100644
--- a/resemble_enhance/enhancer/download.py
+++ b/resemble_enhance/enhancer/download.py
@@ -11,6 +11,8 @@
 def get_source_url(relpath):
     return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
 
+def get_safetensors_url(relpath):
+    return f"https://huggingface.co/rsxdalv/resemble-enhance/resolve/main/{relpath}?download=true"
 
 def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
     if run_dir is None:
@@ -18,7 +20,25 @@
     return Path(run_dir) / relpath
 
 
-def download(run_dir: str | Path | None = None):
+def download(run_dir: str | Path | None = None, safetensors: bool = False) -> Path:
+    relpaths_safetensors = [
+        "denoiser/config.json",
+        "denoiser/model.safetensors",
+        "denoiser/model_info.json",
+        "enhancer/config.json",
+        "enhancer/model.safetensors",
+        "enhancer/model_info.json",
+    ]
+    if safetensors:
+        for relpath in relpaths_safetensors:
+            path = get_target_path(relpath, run_dir=run_dir)
+            if path.exists():
+                continue
+            url = get_safetensors_url(relpath)
+            path.parent.mkdir(parents=True, exist_ok=True)
+            torch.hub.download_url_to_file(url, str(path))
+        return get_target_path("", run_dir=run_dir)
+
     relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]
     for relpath in relpaths:
         path = get_target_path(relpath, run_dir=run_dir)
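For reference, a sketch of what the safetensors branch produces; the base directory is whatever `get_target_path` resolves to (the default cache, or `run_dir` if one is given):

```python
from resemble_enhance.enhancer.download import download

# Fetches the six files in relpaths_safetensors and returns the base directory:
#
#   <base>/denoiser/config.json
#   <base>/denoiser/model.safetensors
#   <base>/denoiser/model_info.json
#   <base>/enhancer/config.json
#   <base>/enhancer/model.safetensors
#   <base>/enhancer/model_info.json
base_dir = download(safetensors=True)
```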
diff --git a/resemble_enhance/enhancer/enhancer_inference.py b/resemble_enhance/enhancer/enhancer_inference.py
new file mode 100644
index 0000000..cf55253
--- /dev/null
+++ b/resemble_enhance/enhancer/enhancer_inference.py
@@ -0,0 +1,204 @@
+"""
+Enhancer model for inference without training dependencies.
+This is a copy of the enhancer module, but with inference-only imports.
+"""
+import logging
+
+import matplotlib.pyplot as plt
+import torch
+from torch import Tensor, nn
+
+from ..common import Normalizer
+from ..melspec import MelSpectrogram
+from .hparams import HParams
+from .lcfm import CFM, IRMAE, LCFM
+from .univnet import UnivNet
+
+logger = logging.getLogger(__name__)
+
+
+# No-op decorator for inference
+def global_leader_only(fn):
+    """No-op decorator for inference - just returns the function as-is."""
+    return fn
+
+
+# Simple TrainLoop replacement for inference
+class TrainLoop:
+    """No-op TrainLoop class for inference."""
+
+    @staticmethod
+    def get_running_loop():
+        """Always return None for inference (no training loop)."""
+        return None
+
+
+def _maybe(fn):
+    def _fn(*args):
+        if args[0] is None:
+            return None
+        return fn(*args)
+
+    return _fn
+
+
+def _normalize_wav(x: Tensor):
+    return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
+
+
+class EnhancerInference(nn.Module):
+    def __init__(self, hp: HParams):
+        super().__init__()
+        self.hp = hp
+
+        n_mels = self.hp.num_mels
+        vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
+        latent_dim = self.hp.lcfm_latent_dim
+
+        self.lcfm = LCFM(
+            IRMAE(
+                input_dim=n_mels,
+                output_dim=vocoder_input_dim,
+                latent_dim=latent_dim,
+            ),
+            CFM(
+                cond_dim=n_mels,
+                output_dim=self.hp.lcfm_latent_dim,
+                solver_nfe=hp.cfm_solver_nfe,
+                solver_method=hp.cfm_solver_method,
+                time_mapping_divisor=hp.cfm_time_mapping_divisor,
+            ),
+            z_scale=self.hp.lcfm_z_scale,
+        )
+
+        self.lcfm.set_mode_(self.hp.lcfm_training_mode)
+
+        self.mel_fn = MelSpectrogram(hp)
+        self.vocoder = UnivNet(self.hp, vocoder_input_dim)
+
+        # For inference, the denoiser is set separately if needed
+        self.denoiser = None
+        self.normalizer = Normalizer()
+
+        self._eval_lambd = 0.0
+
+        self.dummy: Tensor
+        self.register_buffer("dummy", torch.zeros(1))
+
+    @property
+    def mel_fn(self):
+        return self.vocoder.mel_fn
+
+    def configure_denoiser_(self, denoiser):
+        self.denoiser = denoiser
+
+    def configurate_(self, **kwargs):
+        """Configure model parameters for inference.
+
+        Args:
+            nfe: number of function evaluations
+            solver: solver method
+            lambd: denoiser strength [0, 1]
+            tau: prior temperature [0, 1]
+        """
+        if "nfe" in kwargs and "solver" in kwargs:
+            self.lcfm.cfm.solver.configurate_(kwargs["nfe"], kwargs["solver"])
+        elif "nfe" in kwargs:
+            self.lcfm.cfm.solver.configurate_(kwargs["nfe"], None)
+        elif "solver" in kwargs:
+            self.lcfm.cfm.solver.configurate_(None, kwargs["solver"])
+
+        if "tau" in kwargs:
+            self.lcfm.eval_tau_(kwargs["tau"])
+        if "lambd" in kwargs:
+            self._eval_lambd = kwargs["lambd"]
+
+    def _may_denoise(self, x: Tensor, y: Tensor | None = None):
+        if self.hp.lcfm_training_mode == "cfm" and self.denoiser is not None:
+            return self.denoiser(x, y)
+        return x
+
+    def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
+        """Forward pass for inference.
+
+        Args:
+            x: (b t), mix wavs (fg + bg)
+            y: (b t), fg clean wavs
+            z: (b t), fg distorted wavs
+        Returns:
+            o: (b t), reconstructed wavs
+        """
+        assert x.dim() == 2, f"Expected (b t), got {x.size()}"
+        assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"
+
+        if self.hp.lcfm_training_mode == "cfm":
+            self.normalizer.eval()
+
+        x = _normalize_wav(x)
+        y = _maybe(_normalize_wav)(y)
+        z = _maybe(_normalize_wav)(z)
+
+        x_mel_original = self.normalizer(self.to_mel(x), update=False)  # (b d t)
+
+        if self.hp.lcfm_training_mode == "cfm":
+            lambd = self._eval_lambd
+            if lambd == 0:
+                x_mel_denoised = x_mel_original
+            else:
+                x_mel_denoised = self.normalizer(self.to_mel(self._may_denoise(x, z)), update=False)
+                x_mel_denoised = x_mel_denoised.detach()
+                x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original
+        else:
+            x_mel_denoised = x_mel_original
+
+        y_mel = _maybe(self.to_mel)(y)  # (b d t)
+        y_mel = _maybe(self.normalizer)(y_mel)
+
+        if hasattr(self.hp, 'force_gaussian_prior') and self.hp.force_gaussian_prior:
+            lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=None)  # (b d t)
+        else:
+            lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original)  # (b d t)
+
+        if lcfm_decoded is None:
+            o = None
+        else:
+            o = self.vocoder(lcfm_decoded, y)
+
+        return o
+
+    def to_mel(self, x, drop_last=True):
+        """Convert a waveform to a mel-spectrogram.
+
+        Args:
+            x: (b t), wavs
+        Returns:
+            o: (b c t), mels
+        """
+        if drop_last:
+            return self.mel_fn(x)[..., :-1]  # (b d t)
+        return self.mel_fn(x)
+
+    @global_leader_only
+    @torch.no_grad()
+    def _visualize(self, original_mel, denoised_mel):
+        loop = TrainLoop.get_running_loop()
+        if loop is None or loop.global_step % 100 != 0:
+            return
+
+        plt.figure(figsize=(6, 6))
+        plt.subplot(211)
+        plt.title("Original")
+        plt.imshow(original_mel[0].cpu().numpy(), origin="lower", interpolation="none")
+        plt.subplot(212)
+        plt.title("Denoised")
+        plt.imshow(denoised_mel[0].cpu().numpy(), origin="lower", interpolation="none")
+
+        loop.save_current_step_viz(plt.gcf(), self.__class__.__name__, ".png")
+        plt.close()
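Because `EnhancerInference` starts with `denoiser = None`, `_may_denoise` silently passes the input through until a denoiser is attached, so a `lambd > 0` only takes effect after `configure_denoiser_` is called. A wiring sketch using the loaders from this PR; the model paths are examples:

```python
import torch

from resemble_enhance.denoiser.safetensors_loader import load_denoiser_model
from resemble_enhance.enhancer.safetensors_loader import load_enhancer_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# load_enhancer_model falls back to the DeepSpeed layout when no
# model.safetensors/config.json pair is found in the directory.
enhancer = load_enhancer_model("model-repo/enhancer", device)
enhancer.configure_denoiser_(load_denoiser_model("model-repo/denoiser", device))

# nfe/solver configure the CFM ODE solver; lambd blends the denoised and
# original mels; tau is the prior temperature.
enhancer.configurate_(nfe=32, solver="midpoint", lambd=0.5, tau=0.5)
```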
+""" +import logging + +import matplotlib.pyplot as plt +import pandas as pd +import torch +from torch import Tensor, nn +from torch.distributions import Beta + +from ..common import Normalizer +from ..melspec import MelSpectrogram +from .hparams import HParams +from .lcfm import CFM, IRMAE, LCFM +from .univnet import UnivNet + +logger = logging.getLogger(__name__) + + +# No-op decorators for inference +def global_leader_only(fn): + """No-op decorator for inference - just returns the function as-is.""" + return fn + + +# Simple TrainLoop replacement for inference +class TrainLoop: + """No-op TrainLoop class for inference.""" + + @staticmethod + def get_running_loop(): + """Always return None for inference (no training loop).""" + return None + + +def _maybe(fn): + def _fn(*args): + if args[0] is None: + return None + return fn(*args) + + return _fn + + +def _normalize_wav(x: Tensor): + return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7) + + +class EnhancerInference(nn.Module): + def __init__(self, hp: HParams): + super().__init__() + self.hp = hp + + n_mels = self.hp.num_mels + vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim + latent_dim = self.hp.lcfm_latent_dim + + self.lcfm = LCFM( + IRMAE( + input_dim=n_mels, + output_dim=vocoder_input_dim, + latent_dim=latent_dim, + ), + CFM( + cond_dim=n_mels, + output_dim=self.hp.lcfm_latent_dim, + solver_nfe=hp.cfm_solver_nfe, + solver_method=hp.cfm_solver_method, + time_mapping_divisor=hp.cfm_time_mapping_divisor, + ), + z_scale=self.hp.lcfm_z_scale, + ) + + self.lcfm.set_mode_(self.hp.lcfm_training_mode) + + self.mel_fn = MelSpectrogram(hp) + self.vocoder = UnivNet(self.hp, vocoder_input_dim) + + # For inference, denoiser will be set separately if needed + self.denoiser = None + self.normalizer = Normalizer() + + self._eval_lambd = 0.0 + + self.dummy: Tensor + self.register_buffer("dummy", torch.zeros(1)) + + @property + def mel_fn(self): + return self.vocoder.mel_fn + + def configure_denoiser_(self, denoiser): + self.denoiser = denoiser + + def configurate_(self, **kwargs): + """Configure model parameters for inference. + + Args: + nfe: number of function evaluations + solver: solver method + lambd: denoiser strength [0, 1] + tau: prior temperature [0, 1] + """ + if "nfe" in kwargs and "solver" in kwargs: + self.lcfm.cfm.solver.configurate_(kwargs["nfe"], kwargs["solver"]) + elif "nfe" in kwargs: + self.lcfm.cfm.solver.configurate_(kwargs["nfe"], None) + elif "solver" in kwargs: + self.lcfm.cfm.solver.configurate_(None, kwargs["solver"]) + + if "tau" in kwargs: + self.lcfm.eval_tau_(kwargs["tau"]) + if "lambd" in kwargs: + self._eval_lambd = kwargs["lambd"] + + def _may_denoise(self, x: Tensor, y: Tensor | None = None): + if self.hp.lcfm_training_mode == "cfm" and self.denoiser is not None: + return self.denoiser(x, y) + return x + + def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None): + """Forward pass for inference. 
diff --git a/resemble_enhance/enhancer/safetensors_loader.py b/resemble_enhance/enhancer/safetensors_loader.py
new file mode 100644
index 0000000..2beba1c
--- /dev/null
+++ b/resemble_enhance/enhancer/safetensors_loader.py
@@ -0,0 +1,176 @@
+"""
+Enhanced model loader for the enhancer with safetensors support and JSON configs.
+Provides efficient loading without state_dict filtering when using safetensors.
+"""
+import json
+import logging
+from pathlib import Path
+from typing import Dict, Any, Union
+
+import torch
+from safetensors.torch import load_file
+
+from .enhancer_inference import EnhancerInference
+from .hparams import HParams
+
+logger = logging.getLogger(__name__)
+
+
+class JSONConfig:
+    """Simple config class that works with JSON files instead of OmegaConf."""
+
+    def __init__(self, config_dict: Dict[str, Any]):
+        self._config = config_dict
+        # Set attributes for easy access
+        for key, value in config_dict.items():
+            if isinstance(value, dict):
+                setattr(self, key, JSONConfig(value))
+            else:
+                setattr(self, key, value)
+
+    @classmethod
+    def load(cls, json_path: Union[str, Path]) -> 'JSONConfig':
+        """Load config from a JSON file."""
+        with open(json_path, 'r') as f:
+            config_dict = json.load(f)
+        return cls(config_dict)
+
+    def get(self, key: str, default=None):
+        """Get a config value with a default."""
+        return getattr(self, key, default)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert back to a dictionary."""
+        return self._config
+
+
+def load_enhancer_from_safetensors(model_dir: Union[str, Path], device: str = "cpu") -> EnhancerInference:
+    """Load an enhancer model from safetensors format.
+
+    Args:
+        model_dir: Directory containing model.safetensors and config.json
+        device: Device to load the model on
+
+    Returns:
+        Loaded EnhancerInference model
+    """
+    model_path = Path(model_dir)
+
+    # Load config
+    config_path = model_path / "config.json"
+    if not config_path.exists():
+        raise FileNotFoundError(f"Config file not found: {config_path}")
+
+    config = JSONConfig.load(config_path)
+
+    # Create HParams with default values, then update from config
+    hp = HParams()
+
+    # For frozen dataclasses, we need to use object.__setattr__
+    config_dict = config.to_dict()
+    for key, value in config_dict.items():
+        if hasattr(hp, key):
+            try:
+                object.__setattr__(hp, key, value)
+            except Exception:
+                logger.warning(f"Could not set {key}={value} on HParams")
+
+    # Create model
+    model = EnhancerInference(hp)
+
+    # Load weights from safetensors
+    weights_path = model_path / "model.safetensors"
+    if not weights_path.exists():
+        raise FileNotFoundError(f"Model weights not found: {weights_path}")
+
+    state_dict = load_file(weights_path, device=device)
+
+    # No filtering needed - the safetensors file already contains only enhancer weights
+    model.load_state_dict(state_dict)
+    model.eval()
+    model.to(device)
+
+    logger.info(f"Loaded enhancer model from safetensors: {model_path}")
+    return model
+
+
+def load_enhancer_model(run_dir: Union[str, Path, None], device: str = "cpu") -> EnhancerInference:
+    """Load an enhancer model from either a safetensors directory or a DeepSpeed checkpoint.
+
+    Args:
+        run_dir: Path to a model directory (safetensors) or checkpoint directory (DeepSpeed)
+        device: Device to load the model on
+
+    Returns:
+        Loaded EnhancerInference model
+    """
+    if run_dir is None:
+        return create_default_enhancer(device)
+
+    run_dir = Path(run_dir)
+
+    # Check if this is a safetensors model directory
+    if (run_dir / "model.safetensors").exists() and (run_dir / "config.json").exists():
+        logger.info("Loading from safetensors format")
+        return load_enhancer_from_safetensors(run_dir, device)
+
+    # Fall back to DeepSpeed checkpoint loading
+    logger.info("Loading from DeepSpeed checkpoint format")
+    return load_enhancer_from_deepspeed(run_dir, device)
+
+
+def load_enhancer_from_deepspeed(run_dir: Path, device: str = "cpu") -> EnhancerInference:
+    """Load an enhancer model from a DeepSpeed checkpoint (legacy format).
+
+    Args:
+        run_dir: Path to the model checkpoint directory
+        device: Device to load the model on
+
+    Returns:
+        Loaded EnhancerInference model ready for inference
+    """
+    # Load hparams
+    hparams_path = run_dir / "hparams.yaml"
+    if not hparams_path.exists():
+        logger.warning(f"hparams.yaml not found in {run_dir}, using defaults")
+        hp = HParams()
+    else:
+        hp = HParams.load(run_dir)
+
+    # Create model
+    model = EnhancerInference(hp)
+
+    # Load the state dict from the DeepSpeed checkpoint
+    ckpt_path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
+    if not ckpt_path.exists():
+        logger.warning(f"Model checkpoint not found at {ckpt_path}, returning default model")
+        return create_default_enhancer(device)
+
+    state_dict = torch.load(ckpt_path, map_location="cpu")["module"]
+
+    # Filter out denoiser weights, since EnhancerInference has no denoiser submodule
+    filtered_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('denoiser.')}
+
+    model.load_state_dict(filtered_state_dict)
+    model.eval()
+    model.to(device)
+
+    logger.info(f"Loaded enhancer model from DeepSpeed checkpoint: {run_dir}")
+    return model
+
+
+def create_default_enhancer(device: str = "cpu") -> EnhancerInference:
+    """Create a default enhancer model with default hyperparameters.
+
+    Args:
+        device: Device to create the model on
+
+    Returns:
+        Default EnhancerInference model (not trained)
+    """
+    hp = HParams()
+    model = EnhancerInference(hp)
+    model.eval()
+    model.to(device)
+    logger.info("Created a default enhancer model")
+    return model
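`JSONConfig` mirrors the attribute access that OmegaConf provided: nested dicts become nested `JSONConfig` objects. A quick illustration against the class as defined above; the config keys here are illustrative, not the actual schema:

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory

from resemble_enhance.enhancer.safetensors_loader import JSONConfig

with TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "config.json"
    cfg_path.write_text(json.dumps({"num_mels": 128, "cfm": {"solver_nfe": 64}}))

    cfg = JSONConfig.load(cfg_path)
    assert cfg.num_mels == 128         # flat keys become attributes
    assert cfg.cfm.solver_nfe == 64    # nested dicts become JSONConfig
    assert cfg.get("missing", 7) == 7  # dict-style access with a default
```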
diff --git a/resemble_enhance/utils/distributed_inference.py b/resemble_enhance/utils/distributed_inference.py
new file mode 100644
index 0000000..afef0ff
--- /dev/null
+++ b/resemble_enhance/utils/distributed_inference.py
@@ -0,0 +1,24 @@
+"""
+Simplified distributed utilities for inference without DeepSpeed dependencies.
+This module provides no-op versions of the distributed functions for inference use.
+"""
+
+
+def global_leader_only(fn):
+    """No-op decorator for inference - just returns the function as-is."""
+    return fn
+
+
+def local_leader_only(fn):
+    """No-op decorator for inference - just returns the function as-is."""
+    return fn
+
+
+def is_global_leader():
+    """Always return True for inference (single process)."""
+    return True
+
+
+def is_local_leader():
+    """Always return True for inference (single process)."""
+    return True
diff --git a/resemble_enhance/utils/train_loop_inference.py b/resemble_enhance/utils/train_loop_inference.py
new file mode 100644
index 0000000..6445a70
--- /dev/null
+++ b/resemble_enhance/utils/train_loop_inference.py
@@ -0,0 +1,13 @@
+"""
+Simplified TrainLoop utilities for inference without training dependencies.
+This module provides a no-op version of the TrainLoop class for inference use.
+"""
+
+
+class TrainLoop:
+    """No-op TrainLoop class for inference."""
+
+    @staticmethod
+    def get_running_loop():
+        """Always return None for inference (no training loop)."""
+        return None
diff --git a/setup.py b/setup.py
index 8f43617..834d027 100644
--- a/setup.py
+++ b/setup.py
@@ -42,6 +42,11 @@
     long_description_content_type="text/markdown",
     packages=find_packages(),
     install_requires=requirements,
+    extras_require={
+        "train": [
+            "deepspeed>=0.12.4",
+        ],
+    },
     url="https://github.com/resemble-ai/resemble-enhance",
     author="Resemble AI",
     author_email="team@resemble.ai",
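With deepspeed moved out of the core requirements, a plain `pip install resemble-enhance` installs the inference-only dependency set; training users opt back into DeepSpeed through the setuptools extra, e.g. `pip install "resemble-enhance[train]"`.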