@@ -75,20 +75,64 @@ def get_min_capability(cls) -> int:
75
75
def get_config_filenames(cls) -> list[str]:
    """Return the config file names that ship with a ModelOpt checkpoint."""
    config_file = "hf_quant_config.json"
    return [config_file]
78
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional["QuantizationMethods"]:
    """Detect if this ModelOpt config should be used based on
    quantization config.

    Returns "modelopt" when the HF quantization config declares the
    ModelOpt quant method with an FP8 algorithm, otherwise None.
    """

    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method'
    quant_method = hf_quant_cfg.get("quant_method", "").lower()

    # Only proceed if the method is explicitly "modelopt"
    if quant_method != "modelopt":
        return None

    # Look for ModelOpt-specific config structure
    if "quantization" in hf_quant_cfg:
        quant_config = hf_quant_cfg["quantization"]
        if isinstance(quant_config, dict):
            quant_algo = quant_config.get("quant_algo", "")
            # Fix: guard against non-string quant_algo (e.g. None) so a
            # malformed config returns None instead of raising TypeError;
            # the flat branch below already had this check.
            if isinstance(quant_algo, str) and "FP8" in quant_algo:
                return "modelopt"
    else:
        # Check for compressed-tensors style config with specific quant_algo
        quant_algo = hf_quant_cfg.get("quant_algo", "")
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"

    return None
78
109
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
    """Build a ModelOptFp8Config from an hf_quant_config.json dict.

    Accepts either the nested ModelOpt layout
    ({"quantization": {"quant_algo": ...}}) or the flat
    compressed-tensors style layout
    ({"quant_algo": ..., "quant_method": "modelopt"}).

    Raises:
        ValueError: if the quantization section is malformed or the
            quant algorithm is not supported.
    """
    # Normalize both layouts to a single `source` dict so the shared
    # fields are extracted exactly once (previously duplicated per branch).
    if "quantization" in config:
        # ModelOpt format: {"quantization": {"quant_algo": "..."}}
        quant_config = cls.get_from_keys(config, ["quantization"])
        if not isinstance(quant_config, dict):
            raise ValueError(
                "Expected 'quantization' to be a dictionary in config")
        quant_method = quant_config.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")
        source = quant_config
    else:
        # Compressed-tensors style format:
        # {"quant_algo": "...", "quant_method": "modelopt"}
        # NOTE: an absent quant_algo yields "" here and is rejected by the
        # QUANT_ALGOS check below rather than a dedicated error.
        quant_method = config.get("quant_algo", "")
        source = config

    kv_cache_quant_method = source.get("kv_cache_quant_algo")
    exclude_modules = source.get("exclude_modules")

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration.")
    is_checkpoint_fp8_serialized = ("FP8" in quant_method)

    return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
               exclude_modules)
@@ -434,7 +478,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
434
478
def __init__ (
435
479
self ,
436
480
is_checkpoint_nvfp4_serialized : bool ,
437
- kv_cache_quant_algo : str ,
481
+ kv_cache_quant_algo : Optional [ str ] ,
438
482
exclude_modules : list [str ],
439
483
group_size : int = 16 ,
440
484
) -> None :
@@ -465,24 +509,138 @@ def get_min_capability(cls) -> int:
465
509
def get_config_filenames(cls) -> list[str]:
    """List the quantization config files expected in an NVFP4 checkpoint."""
    return list(("hf_quant_config.json",))
467
511
512
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional["QuantizationMethods"]:
    """Detect if this ModelOpt FP4 config should be used based on
    quantization config.

    Returns "modelopt_fp4" when the HF quantization config declares the
    ModelOpt quant method with an FP4 algorithm, otherwise None.
    """
    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method'
    quant_method = hf_quant_cfg.get("quant_method", "").lower()

    # Only proceed if the method is explicitly "modelopt"
    if quant_method != "modelopt":
        return None

    # Look for ModelOpt-specific config structure
    if "quantization" in hf_quant_cfg:
        quant_config = hf_quant_cfg["quantization"]
        if isinstance(quant_config, dict):
            quant_algo = quant_config.get("quant_algo", "")
            # Fix: guard against non-string quant_algo (e.g. None) so a
            # malformed config returns None instead of raising TypeError;
            # the flat branch below already had this check.
            # NOTE(review): this branch matches "NVFP4" case-sensitively
            # while the flat branch matches "FP4" case-insensitively —
            # preserved as-is; confirm whether that asymmetry is intended.
            if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                return "modelopt_fp4"
    else:
        # Check for compressed-tensors style config with specific
        # quant_algo field
        quant_algo = hf_quant_cfg.get("quant_algo", "")
        if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
            return "modelopt_fp4"

    return None
+
468
543
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
    """Build a ModelOptNvFp4Config from an hf_quant_config.json dict.

    Accepts either the nested ModelOpt layout
    ({"quantization": {"quant_algo": ...}}) or the flat
    compressed-tensors style layout
    ({"quant_algo": ..., "quant_method": "modelopt"}).

    Raises:
        ValueError: if any field has the wrong type, the quant algorithm
            is unsupported, or required NVFP4 fields are missing.
    """

    def _kv_cache_quant_algo(source: dict) -> Optional[str]:
        # No KV cache quantization by default; only a string algo is valid.
        raw = source.get("kv_cache_quant_algo")
        if raw is None:
            return None
        if isinstance(raw, str):
            return raw
        raise ValueError(f"kv_cache_quant_algo must be a string, got "
                         f"{type(raw)}")

    def _group_size(source: dict) -> int:
        # Default group size is 16; coerce int-like values, reject the rest.
        raw = source.get("group_size")
        if raw is None:
            return 16
        if isinstance(raw, int):
            return raw
        try:
            return int(raw)
        except (ValueError, TypeError):
            raise ValueError(f"group_size must be an integer, got "
                             f"{type(raw)}") from None

    def _exclude_modules(source: dict) -> list:
        # Modules excluded from quantization; empty list when unspecified.
        modules = source.get("exclude_modules", [])
        if not isinstance(modules, list):
            raise ValueError(f"exclude_modules must be a list, got "
                             f"{type(modules)}")
        return modules

    # Normalize both layouts to a single `source` dict so the shared field
    # validation runs exactly once (previously duplicated per branch).
    if "quantization" in config:
        # Traditional ModelOpt format:
        # {"quantization": {"quant_algo": "..."}}
        quant_config = cls.get_from_keys(config, ["quantization"])
        if not isinstance(quant_config, dict):
            raise ValueError(
                "Expected 'quantization' to be a dictionary in config")
        quant_method = quant_config.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")
        source = quant_config
    else:
        # Compressed-tensors style format:
        # {"quant_algo": "...", "quant_method": "modelopt"}
        # NOTE: an absent quant_algo yields "" here and is rejected by the
        # QUANT_ALGOS check below rather than a dedicated error.
        quant_method = config.get("quant_algo", "")
        source = config

    kv_cache_quant_algo = _kv_cache_quant_algo(source)
    group_size = _group_size(source)
    exclude_modules = _exclude_modules(source)

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration.")
    is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)

    # For FP4 checkpoints in the nested format, these fields are required
    # (the flat format relies on the defaults applied above).
    if is_checkpoint_nvfp4_serialized and "quantization" in config:
        required_fields = [
            "group_size", "kv_cache_quant_algo", "exclude_modules"
        ]
        missing_fields = [
            field for field in required_fields
            if field not in config["quantization"]
        ]
        if missing_fields:
            raise ValueError(
                f"NVFP4 quantization requires the following fields in "
                f"hf_quant_config.json: {missing_fields}")

    return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
               exclude_modules, group_size)
646
0 commit comments