Skip to content

Commit 0effb4e

Browse files
codereview revision
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent d121cfb commit 0effb4e

File tree

1 file changed

+6
-2
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+6
-2
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
from typing import Any, Dict, List, Optional, Union
33

44
import torch
5-
from compressed_tensors.utils import align_module_device, update_offload_parameter
5+
from compressed_tensors.utils import (
6+
align_module_device,
7+
get_execution_device,
8+
update_offload_parameter,
9+
)
610
from loguru import logger
711
from pydantic import ConfigDict
812
from torch.nn import Module
@@ -590,7 +594,7 @@ def _forward_input_with_kwargs(
590594
kwargs = input_kwargs or self._module_kwargs
591595
kwargs = _sanitize_kwargs(kwargs, module)
592596

593-
inputs = inputs.to(next(module.parameters()).device)
597+
inputs = inputs.to(get_execution_device(module))
594598

595599
return module(inputs, **kwargs)[0]
596600

0 commit comments

Comments
 (0)