Commit ef51f29

Merge pull request #268 from yiliu30/itrex_woq

Integrate ITREX to support int8 models on CPU-only devices

2 parents 77911b4 + e4cc29c

File tree: 7 files changed, +124 -15 lines

examples/models/gpt2/gpt2_woq.py

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+# from xturing.datasets.instruction_dataset import InstructionDataset
+from xturing.models import BaseModel
+
+# Initialize the model: quantize it with a weight-only algorithm and
+# replace the linear layers with ITREX's qbits_linear kernel
+model = BaseModel.create("gpt2_int8")
+
+# Once the model has been quantized, you can run inference directly
+output = model.generate(texts=["Why are LLM models becoming so important?"])
+print(output)
examples/models/llama2/llama2_woq.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# from xturing.datasets.instruction_dataset import InstructionDataset
+from xturing.models import BaseModel
+
+# Initialize the model: quantize it with a weight-only algorithm and
+# replace the linear layers with ITREX's qbits_linear kernel
+model = BaseModel.create("llama2_int8")
+
+# Once the model has been quantized, you can run inference directly
+output = model.generate(texts=["Why are LLM models becoming so important?"])
+print(output)

src/xturing/config/__init__.py

Lines changed: 9 additions & 1 deletion

@@ -1,15 +1,23 @@
 import torch

 from xturing.utils.interactive import is_interactive_execution
+from xturing.utils.logging import configure_logger
+from xturing.utils.utils import assert_install_itrex
+
+logger = configure_logger(__name__)

 # check if cuda is available, if not use cpu and throw warning
 DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 DEFAULT_DTYPE = torch.float16 if DEFAULT_DEVICE.type == "cuda" else torch.float32
 IS_INTERACTIVE = is_interactive_execution()

 if DEFAULT_DEVICE.type == "cpu":
-    print("WARNING: CUDA is not available, using CPU instead, can be very slow")
+    logger.warning("CUDA is not available, using CPU instead; this can be very slow")


 def assert_not_cpu_int8():
     assert DEFAULT_DEVICE.type != "cpu", "Int8 models are not supported on CPU"
+
+
+def assert_cpu_int8_on_itrex():
+    if DEFAULT_DEVICE.type == "cpu":
+        assert_install_itrex()
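For reference, a minimal standalone sketch of the new gating logic (this mirrors assert_cpu_int8_on_itrex as an illustration; it is not the module's code):

import importlib.util

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    # int8 on CPU is now allowed as long as ITREX is importable
    assert importlib.util.find_spec("intel_extension_for_transformers") is not None, (
        "Install intel-extension-for-transformers to run int8 models on CPU"
    )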

src/xturing/engines/causal.py

Lines changed: 32 additions & 12 deletions

@@ -19,9 +19,13 @@
 )
 from xturing.engines.quant_utils.peft_utils import LoraConfig as peftLoraConfig
 from xturing.engines.quant_utils.peft_utils import prepare_model_for_kbit_training
+from xturing.utils.logging import configure_logger
 from xturing.utils.loss_fns import CrossEntropyLoss
+from xturing.utils.utils import assert_install_itrex


+logger = configure_logger(__name__)
+
 class CausalEngine(BaseEngine):
     def __init__(
         self,
@@ -60,18 +64,34 @@ def __init__(
             self.tokenizer = tokenizer
         elif model_name is not None:
             if load_8bit:
-                device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    torch_dtype=DEFAULT_DTYPE,
-                    load_in_8bit=True,
-                    device_map=device_map,
-                    trust_remote_code=trust_remote_code,
-                    **kwargs,
-                )
-                for param in self.model.parameters():
-                    param.data = param.data.contiguous()
-                self.model = prepare_model_for_int8_training(self.model)
+                use_itrex = DEFAULT_DEVICE.type == "cpu"
+                if use_itrex:
+                    logger.info("CUDA is not available; using CPU and running the model with ITREX.")
+                    assert_install_itrex()
+                    # quantize the model with weight-only quantization
+                    from intel_extension_for_transformers.transformers import AutoModelForCausalLM as ItrexAutoModelForCausalLM
+                    from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig
+                    woq_config = WeightOnlyQuantConfig(weight_dtype="int8")
+                    self.model = ItrexAutoModelForCausalLM.from_pretrained(
+                        model_name,
+                        quantization_config=woq_config,
+                        trust_remote_code=trust_remote_code,
+                        use_llm_runtime=False,
+                        **kwargs,
+                    )
+                    logger.info("Loaded int8 model via ITREX.")
+                else:
+                    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name,
+                        torch_dtype=DEFAULT_DTYPE,
+                        load_in_8bit=True,
+                        device_map=device_map,
+                        trust_remote_code=trust_remote_code,
+                        **kwargs,
+                    )
+                    for param in self.model.parameters():
+                        param.data = param.data.contiguous()
+                    self.model = prepare_model_for_int8_training(self.model)
             else:
                 self.model = AutoModelForCausalLM.from_pretrained(
                     model_name,
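For a standalone view of the new CPU path, here is a minimal sketch using the same ITREX API as above, assuming intel-extension-for-transformers is installed (the "gpt2" checkpoint is illustrative):

from transformers import AutoTokenizer

from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    WeightOnlyQuantConfig,
)

# int8 weight-only quantization, matching the engine code above
woq_config = WeightOnlyQuantConfig(weight_dtype="int8")
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    quantization_config=woq_config,
    use_llm_runtime=False,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Why are LLM models becoming so important?", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))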

src/xturing/models/causal.py

Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,7 @@
 from tqdm import tqdm
 from transformers import BatchEncoding

-from xturing.config import DEFAULT_DEVICE, assert_not_cpu_int8
+from xturing.config import DEFAULT_DEVICE, assert_not_cpu_int8, assert_cpu_int8_on_itrex
 from xturing.config.config_data_classes import FinetuningConfig, GenerationConfig
 from xturing.config.read_config import load_config
 from xturing.datasets.instruction_dataset import InstructionDataset
@@ -320,7 +320,7 @@ def __init__(
         model_name: Optional[str] = None,
         **kwargs,
     ):
-        assert_not_cpu_int8()
+        assert_cpu_int8_on_itrex()
         super().__init__(
             engine,
             weights_path=weights_path,
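The practical effect of this swap: creating an int8 model on a CPU-only machine no longer fails with "Int8 models are not supported on CPU"; it now proceeds through the ITREX path as long as the package is available. A short illustration:

from xturing.models import BaseModel

# Previously (CPU-only): AssertionError from assert_not_cpu_int8()
# Now (CPU-only): loads via ITREX weight-only quantization, if installed
model = BaseModel.create("gpt2_int8")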

src/xturing/utils/utils.py

Lines changed: 29 additions & 0 deletions

@@ -150,3 +150,32 @@ def _index_samples(samples: List[Any], logger: logging.Logger):
     logger.info(f"Evaluating {len(indices)} samples")
     work_items = [(samples[i], i) for i in indices]
     return work_items
+
+
+def is_itrex_available():
+    """
+    Check the availability of 'intel_extension_for_transformers' as an optional
+    dependency, attempting a pip installation if it is missing.
+
+    Returns:
+        bool: True if 'intel_extension_for_transformers' is available (or was
+        installed successfully), False otherwise.
+    """
+    import importlib.util
+    import subprocess
+    import sys
+
+    if importlib.util.find_spec("intel_extension_for_transformers") is not None:
+        return True
+    try:
+        subprocess.check_call(
+            [sys.executable, "-m", "pip", "install", "intel-extension-for-transformers"]
+        )
+        return importlib.util.find_spec("intel_extension_for_transformers") is not None
+    except subprocess.CalledProcessError:
+        return False
+
+
+def assert_install_itrex():
+    assert is_itrex_available(), (
+        "To run int8 or k-bit models on CPU, please install the "
+        "`intel-extension-for-transformers` package. You can install it with "
+        "`pip install intel-extension-for-transformers`."
+    )
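A quick usage sketch of the new helpers (note that is_itrex_available has the side effect of attempting a pip install when the package is missing):

from xturing.utils.utils import assert_install_itrex, is_itrex_available

if is_itrex_available():
    # safe to import ITREX modules from here on
    from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig

# or fail loudly with an actionable message:
assert_install_itrex()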

tests/xturing/models/test_gpt2_model.py

Lines changed: 32 additions & 0 deletions

@@ -101,3 +101,35 @@ def test_saving_loading_model_lora():

     model2 = BaseModel.load(str(saving_path))
     model2.generate(texts=["Why are the LLM so important?"])
+
+
+import os
+
+
+def disable_cuda(func):
+    def wrapper(*args, **kwargs):
+        # Save the current value of CUDA_VISIBLE_DEVICES
+        original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        # Set CUDA_VISIBLE_DEVICES to -1 to disable CUDA
+        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+        try:
+            # Call the decorated function; let any exception propagate so a
+            # failure actually fails the test instead of being swallowed
+            return func(*args, **kwargs)
+        finally:
+            # Restore the original value of CUDA_VISIBLE_DEVICES
+            if original_cuda_visible_devices is not None:
+                os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+            elif "CUDA_VISIBLE_DEVICES" in os.environ:
+                # It was not set before, so remove it from the environment
+                del os.environ["CUDA_VISIBLE_DEVICES"]
+
+    return wrapper
+
+
+@disable_cuda
+def test_gpt2_int8_woq_cpu():
+    # quantize gpt2 with ITREX and check that generation produces output
+    other_model = BaseModel.create("gpt2_int8")
+    assert other_model.generate(texts="I want to") != ""
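One caveat worth noting: CUDA_VISIBLE_DEVICES is only consulted when the CUDA runtime initializes, so if an earlier test in the same process has already touched the GPU, setting it to -1 here may not take effect. To exercise the CPU path in isolation, the test can be run on its own, e.g. pytest tests/xturing/models/test_gpt2_model.py -k test_gpt2_int8_woq_cpu (assuming pytest as the test runner).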
