Skip to content

Commit c601238

Browse files
author
Swati Allabadi
committed
PP tests
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
1 parent bafc879 commit c601238

File tree

5 files changed

+1041
-11
lines changed

5 files changed

+1041
-11
lines changed

QEfficient/finetune/experimental/core/config_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from dataclasses import asdict, dataclass, field, fields, is_dataclass
1717
from pathlib import Path
1818
from typing import Any, Dict, List, Mapping, Optional, Union
19+
from QEfficient.finetune.experimental.core.utils import constants
1920

2021
import yaml
2122
from transformers.hf_argparser import HfArgumentParser
@@ -855,7 +856,7 @@ def get_model_config(self) -> Dict[str, Any]:
855856
training_dtype = training_config.get("torch_dtype")
856857
if training_dtype:
857858
# Convert from training format (fp16/bf16) to model format (float16/bfloat16)
858-
dtype_mapping = {"fp16": "float16", "bf16": "bfloat16"}
859+
dtype_mapping = constants.DTYPE_MAPPING
859860
model_config["torch_dtype"] = dtype_mapping.get(training_dtype, "auto")
860861

861862
return model_config

QEfficient/finetune/experimental/core/model.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,6 @@ def _resolve_auto_class(auto_class_name: str) -> Type:
105105
)
106106
return getattr(transformers, auto_class_name)
107107

108-
# def _build_quant_config(self) -> Optional[BitsAndBytesConfig]:
109-
# if not self.model_kwargs.get("load_in_4bit"):
110-
# return None
111-
# return BitsAndBytesConfig(
112-
# load_in_4bit=True,
113-
# bnb_4bit_quant_type=self.model_kwargs.get("bnb_4bit_quant_type", "nf4"),
114-
# bnb_4bit_compute_dtype=self.model_kwargs.get("bnb_4bit_compute_dtype", torch.float16),
115-
# bnb_4bit_use_double_quant=self.model_kwargs.get("bnb_4bit_use_double_quant", True),
116-
# )
117-
118108
def configure_model_kwargs(self) -> Dict[str, Any]:
119109
"""Hook for subclasses to tweak HF `.from_pretrained` kwargs."""
120110

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
DTYPE_MAPPING = {"fp16": "float16", "bf16": "bfloat16"}

0 commit comments

Comments (0)