File tree Expand file tree Collapse file tree 3 files changed +24
-7
lines changed
compressors/model_compressors Expand file tree Collapse file tree 3 files changed +24
-7
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.utils import dispatch_for_generation

# Alternate compressed checkpoint kept for quick A/B comparison — TODO confirm
# it is still needed before merging.
# MODEL_ID = "nm-testing/TinyLlama-1.1B-Chat-v1.0-w4a16-asym-awq-e2e"
MODEL_ID = "nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4"

# Load the compressed checkpoint and its tokenizer from the Hugging Face Hub.
# torch_dtype="auto" lets transformers pick the dtype stored in the checkpoint.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Smoke test: dispatch the model for generation, then decode a short
# completion from a fixed prompt so the output can be eyeballed.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
# Fix: removed the stray trailing '' that made this line a syntax error.
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
Original file line number Diff line number Diff line change @@ -338,10 +338,10 @@ def __init__(
338338
339339 self .quantization_compressor = {}
340340 for format in self .compression_formats :
341- self .quantization_compressor [
342- format
343- ] = BaseCompressor . load_from_registry (
344- format , config = quantization_config
341+ self .quantization_compressor [format ] = (
342+ BaseCompressor . load_from_registry (
343+ format , config = quantization_config
344+ )
345345 )
346346
347347 def get_missing_module_keys (self , model : Module ) -> List [str ]:
Original file line number Diff line number Diff line change @@ -116,9 +116,11 @@ def calculate_qparams(
116116 # 4. Update any 0s with small values to
117117 # prevent div by 0
118118 eps = _get_dtype_eps (
119- dtype = quantization_args .scale_dtype
120- if quantization_args .scale_dtype is not None
121- else scales .dtype
119+ dtype = (
120+ quantization_args .scale_dtype
121+ if quantization_args .scale_dtype is not None
122+ else scales .dtype
123+ )
122124 )
123125 scales = torch .where (
124126 scales == 0 ,
You can’t perform that action at this time.
0 commit comments