|
19 | 19 | ) |
20 | 20 | from xturing.engines.quant_utils.peft_utils import LoraConfig as peftLoraConfig |
21 | 21 | from xturing.engines.quant_utils.peft_utils import prepare_model_for_kbit_training |
| 22 | +from xturing.utils.logging import configure_logger |
22 | 23 | from xturing.utils.loss_fns import CrossEntropyLoss |
| 24 | +from xturing.utils.utils import assert_install_itrex |
23 | 25 |
|
24 | 26 |
|
| 27 | +logger = configure_logger(__name__) |
| 28 | + |
25 | 29 | class CausalEngine(BaseEngine): |
26 | 30 | def __init__( |
27 | 31 | self, |
@@ -60,18 +64,34 @@ def __init__( |
60 | 64 | self.tokenizer = tokenizer |
61 | 65 | elif model_name is not None: |
62 | 66 | if load_8bit: |
63 | | - device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} |
64 | | - self.model = AutoModelForCausalLM.from_pretrained( |
65 | | - model_name, |
66 | | - torch_dtype=DEFAULT_DTYPE, |
67 | | - load_in_8bit=True, |
68 | | - device_map=device_map, |
69 | | - trust_remote_code=trust_remote_code, |
70 | | - **kwargs, |
71 | | - ) |
72 | | - for param in self.model.parameters(): |
73 | | - param.data = param.data.contiguous() |
74 | | - self.model = prepare_model_for_int8_training(self.model) |
| 67 | + use_itrex = DEFAULT_DEVICE.type == "cpu" |
| 68 | + if use_itrex: |
| 69 | + logger.info("CUDA is not available, using CPU instead, running the model with itrex.") |
| 70 | + assert_install_itrex() |
| 71 | + # quantize model with weight-only quantization |
| 72 | + from intel_extension_for_transformers.transformers import AutoModelForCausalLM as ItrexAutoModelForCausalLM |
| 73 | + from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig |
| 74 | + woq_config = WeightOnlyQuantConfig(weight_dtype='int8') |
| 75 | + self.model = ItrexAutoModelForCausalLM.from_pretrained( |
| 76 | + model_name, |
| 77 | + quantization_config=woq_config, |
| 78 | + trust_remote_code=trust_remote_code, |
| 79 | + use_llm_runtime=False, |
| 80 | + **kwargs) |
| 81 | + logger.info("Loaded int8 model from Itrex.") |
| 82 | + else: |
| 83 | + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} |
| 84 | + self.model = AutoModelForCausalLM.from_pretrained( |
| 85 | + model_name, |
| 86 | + torch_dtype=DEFAULT_DTYPE, |
| 87 | + load_in_8bit=True, |
| 88 | + device_map=device_map, |
| 89 | + trust_remote_code=trust_remote_code, |
| 90 | + **kwargs, |
| 91 | + ) |
| 92 | + for param in self.model.parameters(): |
| 93 | + param.data = param.data.contiguous() |
| 94 | + self.model = prepare_model_for_int8_training(self.model) |
75 | 95 | else: |
76 | 96 | self.model = AutoModelForCausalLM.from_pretrained( |
77 | 97 | model_name, |
|
0 commit comments