From 7ab712429da0a0e0ca9871d9f0208705940bdba6 Mon Sep 17 00:00:00 2001 From: Koke_Cacao Date: Sat, 26 Oct 2024 13:25:32 -0400 Subject: [PATCH] Fix custom fwd and bwd for older PyTorch versions https://github.com/state-spaces/mamba/pull/596#issuecomment-2439662703 --- mamba_ssm/utils/torch.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mamba_ssm/utils/torch.py b/mamba_ssm/utils/torch.py index afe1dfcf..37df47c8 100644 --- a/mamba_ssm/utils/torch.py +++ b/mamba_ssm/utils/torch.py @@ -1,16 +1,18 @@ import torch from functools import partial +from typing import Callable - -def custom_amp_decorator(dec, cuda_amp_deprecated): - def decorator(func): - return dec(func) if not cuda_amp_deprecated else partial(dec, func, device_type="cuda") +def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) return decorator -if hasattr(torch.amp, "custom_fwd"): +if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] deprecated = True - from torch.amp import custom_fwd, custom_bwd + from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] else: deprecated = False from torch.cuda.amp import custom_fwd, custom_bwd