diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py
index f318dd42..59e17189 100644
--- a/exllamav2/architecture.py
+++ b/exllamav2/architecture.py
@@ -319,6 +319,7 @@ class Params:
                 layer_keys_llama_mlp
             self.lm.expect_keys += \
                 expect_keys_llama
+            self.lm.supports_tp = True
 
             self.vt_prefix = "vision_tower."
             self.vt.keys.update({
@@ -478,6 +479,7 @@ class Params:
             self.lm.attention_bias_qkv = True
             self.lm.mrope = True
             self.lm.rope_freq_half = True
+            self.lm.supports_tp = True
 
             self.vt_prefix = "visual."
             if arch_string == "Qwen2VLForConditionalGeneration":
diff --git a/exllamav2/tensor_p.py b/exllamav2/tensor_p.py
index 2ba6f85a..3ebc7f16 100644
--- a/exllamav2/tensor_p.py
+++ b/exllamav2/tensor_p.py
@@ -351,16 +351,24 @@ def allgather(
 
         return bc_tensors
 
-    def copy_pinned(
-        self,
-        buffer: int,
-        inputs: torch.Tensor
-    ):
+#    def copy_pinned(
+#        self,
+#        buffer: int,
+#        inputs: torch.Tensor
+#    ):
+#        pt = self.pinned_temp[buffer][:inputs.numel()]
+#        pt = pt.view(inputs.shape)
+#        pt.copy_(inputs)
+#        return pt
+
+    def copy_pinned(self, buffer: int, inputs: torch.Tensor):
 
         pt = self.pinned_temp[buffer][:inputs.numel()]
         pt = pt.view(inputs.shape)
-        pt.copy_(inputs)
-        return pt
+        # Bypass PyTorch entirely - direct memory copy
+        import ctypes
+        ctypes.memmove(pt.data_ptr(), inputs.data_ptr(), inputs.numel() * inputs.element_size())
+        return pt
 
     def add_residual(
         self,
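
For reference, here is a minimal standalone sketch of the technique the rewritten `copy_pinned` relies on: staging a host tensor in a pre-allocated pinned buffer via `ctypes.memmove` instead of `Tensor.copy_()`. The helper name `copy_into_pinned` and the buffer sizes are illustrative only, not part of exllamav2's API, and the sketch assumes both tensors are contiguous CPU tensors of the same dtype (the raw byte copy is only valid under those conditions).

```python
# Illustrative sketch, not part of the patch.
import ctypes
import torch

def copy_into_pinned(pinned_temp: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    # The raw byte copy below is only well-defined for contiguous host tensors
    # of matching dtype; guard against anything else.
    assert not inputs.is_cuda and inputs.is_contiguous()
    assert pinned_temp.is_pinned() and pinned_temp.dtype == inputs.dtype

    # Take a view of the front of the pinned buffer shaped like the input,
    # then copy the bytes directly between the two host allocations.
    pt = pinned_temp[:inputs.numel()].view(inputs.shape)
    ctypes.memmove(pt.data_ptr(), inputs.data_ptr(),
                   inputs.numel() * inputs.element_size())
    return pt

# Usage: stage a tensor in a reusable pinned buffer before an async H2D copy.
pinned = torch.empty(1 << 20, dtype=torch.half, pin_memory=True)
x = torch.randn(128, 64).half()
staged = copy_into_pinned(pinned, x)
assert torch.equal(staged, x)
```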