Skip to content

Commit 492838e

Browse files
quic-swatia (Swati Allabadi)
and a co-author authored
[QEff. Finetuning]: Tests for Pipeline Parallelism and updated documentation (#893)
1) Added unit test cases for Pipeline Parallelism. 2) Added documentation on how to run these tests. 3) Created a constants file. Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com> Co-authored-by: Swati Allabadi <sallabad@qti.qualcomm.com>
1 parent bafc879 commit 492838e

File tree

5 files changed

+1041
-11
lines changed

5 files changed

+1041
-11
lines changed

QEfficient/finetune/experimental/core/config_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from transformers.hf_argparser import HfArgumentParser
2222

2323
from QEfficient.finetune.experimental.core.logger import Logger
24+
from QEfficient.finetune.experimental.core.utils import constants
2425
from QEfficient.finetune.experimental.core.utils.dist_utils import is_main_process
2526
from QEfficient.utils.device_utils import is_nsp_free
2627

@@ -855,7 +856,7 @@ def get_model_config(self) -> Dict[str, Any]:
855856
training_dtype = training_config.get("torch_dtype")
856857
if training_dtype:
857858
# Convert from training format (fp16/bf16) to model format (float16/bfloat16)
858-
dtype_mapping = {"fp16": "float16", "bf16": "bfloat16"}
859+
dtype_mapping = dtype_mapping = constants.DTYPE_MAPPING
859860
model_config["torch_dtype"] = dtype_mapping.get(training_dtype, "auto")
860861

861862
return model_config

QEfficient/finetune/experimental/core/model.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,6 @@ def _resolve_auto_class(auto_class_name: str) -> Type:
105105
)
106106
return getattr(transformers, auto_class_name)
107107

108-
# def _build_quant_config(self) -> Optional[BitsAndBytesConfig]:
109-
# if not self.model_kwargs.get("load_in_4bit"):
110-
# return None
111-
# return BitsAndBytesConfig(
112-
# load_in_4bit=True,
113-
# bnb_4bit_quant_type=self.model_kwargs.get("bnb_4bit_quant_type", "nf4"),
114-
# bnb_4bit_compute_dtype=self.model_kwargs.get("bnb_4bit_compute_dtype", torch.float16),
115-
# bnb_4bit_use_double_quant=self.model_kwargs.get("bnb_4bit_use_double_quant", True),
116-
# )
117-
118108
def configure_model_kwargs(self) -> Dict[str, Any]:
119109
"""Hook for subclasses to tweak HF `.from_pretrained` kwargs."""
120110

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
DTYPE_MAPPING = {"fp16": "float16", "bf16": "bfloat16"}

0 commit comments

Comments (0)