
Commit 19ebc50

quality

Signed-off-by: shanjiaz <[email protected]>
1 parent fd73ced

3 files changed: +24 -7 lines

nvfp4_decompress.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from llmcompressor.utils import dispatch_for_generation
+
+# MODEL_ID = "nm-testing/TinyLlama-1.1B-Chat-v1.0-w4a16-asym-awq-e2e"
+MODEL_ID = "nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4"
+
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 4 additions & 4 deletions

@@ -338,10 +338,10 @@ def __init__(

         self.quantization_compressor = {}
         for format in self.compression_formats:
-            self.quantization_compressor[
-                format
-            ] = BaseCompressor.load_from_registry(
-                format, config=quantization_config
+            self.quantization_compressor[format] = (
+                BaseCompressor.load_from_registry(
+                    format, config=quantization_config
+                )
             )

     def get_missing_module_keys(self, model: Module) -> List[str]:
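For context, this hunk is behavior-preserving: the loop still builds one compressor per format string via the registry lookup. A minimal standalone sketch of that pattern (class and format names here are illustrative, not the compressed-tensors internals):

from typing import Any, Dict, Type

class BaseCompressor:
    # Hypothetical registry for illustration; the real library
    # wires registration up differently.
    _registry: Dict[str, Type["BaseCompressor"]] = {}

    def __init__(self, config: Any = None):
        self.config = config

    @classmethod
    def register(cls, name: str):
        def decorator(subcls: Type["BaseCompressor"]):
            cls._registry[name] = subcls
            return subcls
        return decorator

    @classmethod
    def load_from_registry(cls, name: str, config: Any = None) -> "BaseCompressor":
        return cls._registry[name](config=config)

@BaseCompressor.register("nvfp4-pack-quantized")
class NVFP4Compressor(BaseCompressor):
    pass

# Mirrors the loop in __init__ above: one compressor instance per format.
quantization_compressor = {
    fmt: BaseCompressor.load_from_registry(fmt, config=None)
    for fmt in ("nvfp4-pack-quantized",)
}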

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 5 additions & 3 deletions

@@ -116,9 +116,11 @@ def calculate_qparams(
     # 4. Update any 0s with small values to
     # prevent div by 0
     eps = _get_dtype_eps(
-        dtype=quantization_args.scale_dtype
-        if quantization_args.scale_dtype is not None
-        else scales.dtype
+        dtype=(
+            quantization_args.scale_dtype
+            if quantization_args.scale_dtype is not None
+            else scales.dtype
+        )
     )
     scales = torch.where(
         scales == 0,
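The torch.where call after this hunk replaces any zero scale with eps, so the subsequent division by the scales cannot divide by zero. A standalone sketch of that guard (using torch.finfo directly in place of the library's _get_dtype_eps helper):

import torch

# Zero scales would make x / scales divide by zero during quantization;
# replace them with the dtype's machine epsilon instead.
scales = torch.tensor([0.5, 0.0, 1.25], dtype=torch.bfloat16)
eps = torch.finfo(scales.dtype).eps
scales = torch.where(scales == 0, torch.full_like(scales, eps), scales)
print(scales)  # the 0.0 entry becomes ~0.0078 (bfloat16 eps)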
