
feat: add Qwen3.5 MoE calibration module #2383

Closed
Sehyo wants to merge 12 commits into vllm-project:main from Sehyo:feat/qwen3-5-moe-calibration


@Sehyo Sehyo commented Feb 18, 2026

Summary

  • Add CalibrationQwen3_5MoeSparseMoeBlock calibration module that unfuses Qwen3.5's 3D fused expert parameters into individual Qwen3_5MoeMLP modules with nn.Linear layers, enabling NVFP4 quantization of expert weights
  • Register the module in modeling/__init__.py
  • Add NVFP4 quantization example script for Qwen/Qwen3.5-397B-A17B

Details

Qwen3.5 MoE (Qwen3_5MoeSparseMoeBlock) stores all expert weights in fused 3D nn.Parameter tensors (gate_up_proj: [num_experts, 2*intermediate, hidden], down_proj: [num_experts, hidden, intermediate]). The calibration module unfuses these into individual MLP modules so targets="Linear" can match and quantize them.

The implementation follows the same pattern as CalibrateQwen3VLMoeTextSparseMoeBlock with is_permanent=True, and includes disable_onloading() for safe CPU access to offloaded parameters on large models.
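To make the unfusing step concrete, here is a minimal sketch in plain Python, with nested lists standing in for tensors. It assumes the gate rows are stored before the up rows inside gate_up_proj, which the description above does not state. `unfuse_experts` and the dictionary keys are hypothetical names for illustration, not the PR's API.

```python
def unfuse_experts(gate_up_proj, down_proj):
    """Split fused per-expert weights into separate 2D matrices.

    gate_up_proj: [num_experts][2 * intermediate][hidden]
    down_proj:    [num_experts][hidden][intermediate]
    Assumption: gate rows are stored before up rows in gate_up_proj.
    """
    experts = []
    for fused, down in zip(gate_up_proj, down_proj):
        intermediate = len(fused) // 2
        experts.append({
            "gate_proj": fused[:intermediate],  # [intermediate][hidden]
            "up_proj": fused[intermediate:],    # [intermediate][hidden]
            "down_proj": down,                  # [hidden][intermediate]
        })
    return experts


# Toy sizes: 2 experts, hidden=3, intermediate=2.
num_experts, hidden, intermediate = 2, 3, 2
gate_up = [
    [[float(e * 100 + r * 10 + c) for c in range(hidden)]
     for r in range(2 * intermediate)]
    for e in range(num_experts)
]
down = [
    [[float(e) for _ in range(intermediate)] for _ in range(hidden)]
    for e in range(num_experts)
]

experts = unfuse_experts(gate_up, down)
print(len(experts), len(experts[0]["gate_proj"]), len(experts[0]["gate_proj"][0]))
```

Each resulting 2D matrix has the [out_features, in_features] layout that an nn.Linear weight uses, which is what lets a targets="Linear" recipe match one module per projection per expert.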

@gemini-code-assist
Contributor

Summary of Changes

Hello @Sehyo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a specialized calibration module for Qwen3.5 Mixture-of-Experts (MoE) models, designed to facilitate efficient NVFP4 quantization of their expert weights. By dynamically restructuring the MoE block to expose individual expert layers as standard linear modules, it enables the application of fine-grained quantization techniques. A new example script demonstrates this process, ensuring broader compatibility and optimized performance for these large language models.

Highlights

  • Qwen3.5 MoE Calibration Module: Introduced CalibrationQwen3_5MoeSparseMoeBlock to enable NVFP4 quantization for Qwen3.5 MoE models.
  • Expert Parameter Unfusing: This new module unfuses Qwen3.5's 3D fused expert parameters into individual nn.Linear layers, making them targetable for quantization.
  • Module Registration: The new calibration module has been registered in modeling/__init__.py.
  • NVFP4 Quantization Example: An example script (qwen3_5_moe_example.py) was added to demonstrate NVFP4 quantization for the Qwen/Qwen3.5-397B-A17B model.


Changelog
  • examples/quantization_w4a4_fp4/qwen3_5_moe_example.py
    • Added a new example script for NVFP4 quantization of Qwen3.5 MoE models.
  • src/llmcompressor/modeling/__init__.py
    • Imported CalibrationQwen3_5MoeSparseMoeBlock.
    • Registered the new Qwen3.5 MoE calibration module.
  • src/llmcompressor/modeling/qwen3_5_moe.py
    • Added CalibrationQwen3_5MoeSparseMoeBlock to unfuse 3D expert parameters into individual nn.Linear modules for quantization.
    • Implemented SequentialQwen3_5MoeExperts to manage the unfused expert layers.
    • Included logic to handle offloaded parameters safely during unfusing.
Activity
  • No activity (comments, reviews, etc.) was provided in the context.

@mergify mergify bot added the documentation label Feb 18, 2026
@github-actions

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a calibration module for Qwen3.5 MoE models, enabling NVFP4 quantization. The changes include the core module implementation, its registration within the modeling package, and a comprehensive example script demonstrating its usage on a large-scale model. The implementation correctly unfuses expert weights into individual nn.Linear layers, which is crucial for quantization. The approach of using disable_onloading to handle large model weights on the CPU is well-considered. I have identified one potential issue in the forward pass logic that could lead to errors for MoE models configured with top_k=1, and I have provided a suggestion to address it.

@Sehyo Sehyo force-pushed the feat/qwen3-5-moe-calibration branch from 83c7bd8 to 1d428f9 Compare February 18, 2026 11:14
@Sehyo
Author

Sehyo commented Feb 18, 2026

Requesting a review, or alternatively the ready and enhancement tags.
@dsikka @kylesayrs

@Sehyo
Author

Sehyo commented Feb 18, 2026

Quantized version with this PR:
https://huggingface.co/Sehyo/Qwen3.5-397B-A17B-NVFP4

Collaborator

@dsikka dsikka left a comment


This looks really good - thank you!

@dsikka dsikka added the ready, qwen, and nvfp4 labels Feb 18, 2026
@mergify
Contributor

mergify bot commented Feb 18, 2026

The quality checks have failed. Please run make style and make quality under
the root directory to address the lint failures. You will need to install the
dev optional install to get the required linting packages:
https://github.com/vllm-project/llm-compressor/blob/main/CONTRIBUTING.md

@aabbccddwasd

keep getting RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)

@Sehyo
Author

Sehyo commented Feb 19, 2026

> keep getting RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)

Is this an error from vLLM?
Can you please share your GPU setup, vLLM version, and the trace?
Is it the same problem as in vLLM issue #33544?

@Sehyo
Author

Sehyo commented Feb 19, 2026

I have detected an issue in the current upstream version of vLLM that causes the Qwen3.5 NVFP4 quant to fail.

In the Qwen3.5 Gated DeltaNet layers, some projections are fused/merged:
in_proj_qkvz = Q + K + V + Z
in_proj_ba = B + A

and vLLM fuses them like this:
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
("in_proj_ba", "in_proj_b", 0),
("in_proj_ba", "in_proj_a", 1),

This assumes plain weight tensors that can be concatenated, but the NVFP4 format stores weights 4-bit packed in weight_packed, so the fused weights come out as garbage.

I am currently trying to write a fix for this. If I get it working, I will submit a PR to the vLLM repo as well.

@dsikka
Collaborator

dsikka commented Feb 19, 2026

> I have detected an issue in the current upstream version of VLLM which causes the Qwen3.5 NVFP4 quant to fail.
>
> in Qwen 3.5 Gated Delta Net, we have some fused / merged projections: in_proj_qkvz = Q + K + V + Z in_proj_ba = B + A
>
> and VLLM does fusing like: ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)), ("in_proj_qkvz", "in_proj_z", 3), ("in_proj_ba", "in_proj_b", 0), ("in_proj_ba", "in_proj_a", 1),
>
> .. Which assumes plain weight tensors which are concatable.. But NVFP4 format stores weights in weight_packed (4bit packed) way. --> Fused weights are garbage
>
> I am currently trying to write a fix for this, if I succeed to get it working will submit a PR to vllm repo as well.

If we skip quantizing the linear attn layers, won't this issue be resolved?
It seems like you skipped them in your ignore list.

@dsikka
Collaborator

dsikka commented Feb 19, 2026

Do you mind adding a test similar to the tests in this folder: https://github.com/vllm-project/llm-compressor/tree/main/tests/llmcompressor/modeling

@Sehyo
Author

Sehyo commented Feb 19, 2026

> > I have detected an issue in the current upstream version of VLLM which causes the Qwen3.5 NVFP4 quant to fail.
> > in Qwen 3.5 Gated Delta Net, we have some fused / merged projections: in_proj_qkvz = Q + K + V + Z in_proj_ba = B + A
> > and VLLM does fusing like: ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)), ("in_proj_qkvz", "in_proj_z", 3), ("in_proj_ba", "in_proj_b", 0), ("in_proj_ba", "in_proj_a", 1),
> > .. Which assumes plain weight tensors which are concatable.. But NVFP4 format stores weights in weight_packed (4bit packed) way. --> Fused weights are garbage
> > I am currently trying to write a fix for this, if I succeed to get it working will submit a PR to vllm repo as well.
>
> If we skip quantizing the linear attn layers, wont this issue be resolved? It seems like you skipped them in your ignore list

Yes, for those layers it does not matter.
However, I noticed that vLLM's load_weights has no mapping for the unfused checkpoint weights (the Qwen 3.5 HF checkpoint stores unfused names: in_proj_qkv.weight, in_proj_z.weight, etc.). Still looking into it.
Edit: Never mind, the issue I was running into was fixed by a commit in the latest vLLM nightly.

@Sehyo
Author

Sehyo commented Feb 19, 2026

> Do you mind adding a test similar to the tests in this folder: https://github.com/vllm-project/llm-compressor/tree/main/tests/llmcompressor/modeling

Sure, will do it!

@mergify mergify bot removed the quality-failed label Feb 20, 2026
@Sehyo Sehyo force-pushed the feat/qwen3-5-moe-calibration branch 2 times, most recently from 642ba83 to d030961 Compare February 20, 2026 09:51
@Sehyo
Author

Sehyo commented Feb 20, 2026

@dsikka Tests have been added.

@Sehyo
Author

Sehyo commented Feb 22, 2026

Review Request

@Sehyo
Author

Sehyo commented Feb 25, 2026

It came to my attention that the MTP modules are dropped from the quant. I am away until Sunday but can fix it then.

@JartX
Contributor

JartX commented Mar 3, 2026

@Sehyo I would switch it to W4A16 Scheme; the group size is for it to work on Exllama in my RDNA3

@BenasdTW

BenasdTW commented Mar 3, 2026

@Sehyo Hi, I’ve encountered a couple of issues while running a modified version of your example code.

Modification to the quantization script:

scheme_0 = FP8_DYNAMIC
scheme_0["targets"] = ["re:.*self_attn.o_proj", "re:.*linear_attn.in_proj_qkv", "re:.*linear_attn.in_proj_z", "re:.*linear_attn.out_proj"]
scheme_1 = NVFP4
scheme_1["targets"] = ["re:.*self_attn.(q|k|v)_proj", "re:.*mlp.experts.*.*_proj"]

ignore = ["re:.*lm_head", "re:visual.*", "re:model.visual.*", "re:.*mlp.gate$", "re:.*norm.*", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$", "re:.*mtp.*", "re:.*conv1d.*", "re:.*in_proj_a*", "re:.*in_proj_b*", "re:.*in_proj_c*"]
recipe = QuantizationModifier(
    config_groups={"group_0": scheme_0, "group_1": scheme_1}, ignore=ignore
)

Expected behavior: self_attn.o_proj, linear_attn.in_proj_qkv and linear_attn.in_proj_z should be quantized to fp8. self_attn.(q|k|v)_proj and mlp quantized to NVFP4.

Result: Only self_attn.o_proj is quantized to fp8; linear_attn.in_proj_qkv and linear_attn.in_proj_z remain unquantized. NVFP4 quantization works as expected.

Another issue: the exported tokenizer metadata appears to use an unexpected class:

"tokenizer_class": "TokenizersBackend",

@dsikka dsikka mentioned this pull request Mar 4, 2026
@phaelon74
Contributor

> @Sehyo Hi, I’ve encountered a couple of issues while running a modified version of your example code.
>
> Modification to the quantization script:
>
> scheme_0 = FP8_DYNAMIC
> scheme_0["targets"] = ["re:.*self_attn.o_proj", "re:.*linear_attn.in_proj_qkv", "re:.*linear_attn.in_proj_z", "re:.*linear_attn.out_proj"]
> scheme_1 = NVFP4
> scheme_1["targets"] = ["re:.*self_attn.(q|k|v)_proj", "re:.*mlp.experts.*.*_proj"]
>
> ignore = ["re:.*lm_head", "re:visual.*", "re:model.visual.*", "re:.*mlp.gate$", "re:.*norm.*", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$", "re:.*mtp.*", "re:.*conv1d.*", "re:.*in_proj_a*", "re:.*in_proj_b*", "re:.*in_proj_c*"]
> recipe = QuantizationModifier(
>     config_groups={"group_0": scheme_0, "group_1": scheme_1}, ignore=ignore
> )
>
> Expected behavior: self_attn.o_proj, linear_attn.in_proj_qkv and linear_attn.in_proj_z should be quantized to fp8. self_attn.(q|k|v)_proj and mlp quantized to NVFP4.
>
> Result: Only self_attn.o_proj is quantized to fp8, linear_attn.in_proj_qkv and linear_attn.in_proj_z remain unquantized. While the NVFP4 is working as expected.
>
> Another issue: the exported tokenizer metadata appears to use an unexpected class:
>
> "tokenizer_class": "TokenizersBackend",

There was, and still may be, an issue with using mixed precision with NVFP4 in vLLM. Be aware of that, as it may be what is occurring here.

I closed my PR, as I hadn't seen yours, @Sehyo. Your code was very close to mine, and your MTP handling is solid for people who turn it on. Thanks for submitting this.

@BenasdTW

BenasdTW commented Mar 4, 2026

> There was and still may be, an issue using Mixed Precision with NVFP4 in VLLM. Be aware of that, as that may be occurring here.

@phaelon74 Thanks for the information! I’ll open a separate issue to discuss this, since it seems unrelated to this PR. I wonder if this is specific to the new linear_attn module, because self_attn.o_proj is being quantized correctly.

Edit: I found the issue. Turns out the regex wasn’t matching in my script.
Fixed version:

scheme_0 = FP8_DYNAMIC
scheme_0["targets"] = [
    "re:.*self_attn.o_proj$",
    "re:.*linear_attn.in_proj_qkv$",
    "re:.*linear_attn.in_proj_z$",
    "re:.*linear_attn.out_proj$",
]
scheme_1 = NVFP416
scheme_1["targets"] = [
    "re:.*self_attn.(q|k|v)_proj$",
    "re:.*mlp.experts.*.*_proj$",
]
ignore = ["re:.*lm_head", "re:visual.*", "re:model.visual.*", "re:.*mlp.gate$", "re:.*norm.*", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$", "re:.*mtp.*", "re:.*conv1d.*", "re:.*in_proj_a+", "re:.*in_proj_b+", "re:.*in_proj_c+"]
recipe = QuantizationModifier(
    config_groups={"group_0": scheme_0, "group_1": scheme_1}, ignore=ignore
)
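One plausible reading of the fix is that the original ignore patterns were too broad rather than the targets too narrow. The sketch below checks this with Python's `re` module. It assumes that `re:` patterns are applied with `re.match` semantics against the full module name (an assumption about the matcher, not something stated in this thread), and the module names are made up for illustration.

```python
import re

# Hypothetical module names, modeled on the layers discussed above.
names = [
    "model.layers.0.linear_attn.in_proj_qkv",
    "model.layers.0.linear_attn.in_proj_z",
    "model.layers.0.linear_attn.in_proj_a",
    "model.layers.0.linear_attn.out_proj",
]


def ignored(pattern: str, name: str) -> bool:
    # Assumption: a "re:" pattern is matched from the start of the name.
    return re.match(pattern, name) is not None


# "a*" means zero or more "a", so this matches any name containing "in_proj_".
star = {n: ignored(r".*in_proj_a*", n) for n in names}
# "a+" requires at least one "a", so only in_proj_a is ignored.
plus = {n: ignored(r".*in_proj_a+", n) for n in names}

for n in names:
    print(f"{n}: a*={star[n]} a+={plus[n]}")
```

Under that assumption, `re:.*in_proj_a*` would also ignore in_proj_qkv and in_proj_z, which is consistent with the reported result that only self_attn.o_proj ended up in FP8.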

Collaborator

@dsikka dsikka left a comment


Overall this looks fine, but I don't quite understand why we need an updated regex pattern, _update_config_expanded_ignore, or _graft_extra_weights. I think generally, if we want to expand regex mapping, that should be done in a follow-up PR, as it is not specific to Qwen3.5.

I am able to generate quantized checkpoints without this

# by regex (e.g. MoE router modules that aren't nn.Linear).
# Store expanded names on the model so the save wrapper can ensure
# they appear in config.json.
regex_patterns = [p for p in self.ignore if p.startswith("re:")]
Collaborator


Can you explain why you need this?

Contributor


I did not have this in mine, and mine quanted and loaded successfully in VLLM, so would love to know as well.

Collaborator


@Sehyo can you explain why this is required?

@mergify
Contributor

mergify bot commented Mar 5, 2026

The quality checks have failed. Please run make style and make quality under
the root directory to address the lint failures. You will need to install the
dev optional install to get the required linting packages:
https://github.com/vllm-project/llm-compressor/blob/main/CONTRIBUTING.md

@Sehyo
Author

Sehyo commented Mar 7, 2026

> Overall this looks fine, but I don't quite understand why we need an updated regex pattern, _update_config_expanded_ignore, or _graft_extra_weights. I think generally, if we want to expand regex mapping, that should be done in a follow-up PR, as it is not specific to Qwen3.5.
>
> I am able to generate quantized checkpoints without this

_graft_extra_weights is for adding the MTP weights back in, as they otherwise get dropped.

@dsikka
Collaborator

dsikka commented Mar 7, 2026

> > Overall this looks fine, but I don't quite understand why we need an updated regex pattern, _update_config_expanded_ignore, or _graft_extra_weights. I think generally, if we want to expand regex mapping, that should be done in a follow-up PR, as it is not specific to Qwen3.5.
> > I am able to generate quantized checkpoints without this
>
> Graft extra weights is for re-adding MTP weights back in as they get dropped.

@Sehyo I think we want to do this at the end when we're saving the checkpoint, not in the middle of calibration as it does not impact quantization.

Do you mind also resolving the quality issues?

@mergify mergify bot removed the quality-failed label Mar 7, 2026
@mergify
Contributor

mergify bot commented Mar 7, 2026

The quality checks have failed. Please run make style and make quality under
the root directory to address the lint failures. You will need to install the
dev optional install to get the required linting packages:
https://github.com/vllm-project/llm-compressor/blob/main/CONTRIBUTING.md

@paulplay-pm

Hi @Sehyo, could you try to use structured_outputs json with the model?

The script I use to extract the JSON data:

import base64
import fitz
import json
import requests
import sys
import time
from pathlib import Path
from PIL import Image
from io import BytesIO
from typing import List, Optional, Literal
from pydantic import BaseModel
from qwen_vl_utils import smart_resize

TARGET_DPI = 150
PYMUPDF_BASE_DPI = 72
PATCH_MULTIPLE = 32 
MAX_SIDE_PX = 3000

VLLM_URL = "http://192.168.1.75/v1/chat/completions"
MODEL = "QWEN3.5"

class Vendor(BaseModel):
    name: Optional[str]
    address: Optional[str]
    cif_nif: Optional[str]
    confidence: Literal["confident", "partial", "ambiguous"]

class Client(BaseModel):
    name: Optional[str]
    address: Optional[str]
    cif_nif: Optional[str]
    confidence: Literal["confident", "partial", "ambiguous"]

class Dates(BaseModel):
    issue_date: Optional[str]
    due_date: Optional[str]

class Amounts(BaseModel):
    subtotal: Optional[float]
    total_discount: Optional[float]
    tax_rate: Optional[float]
    tax_amount: Optional[float]
    total: Optional[float]
    currency: Optional[str]

class LineItem(BaseModel):
    description: str
    quantity: Optional[float]
    unit_price: Optional[float]
    discount_percentage: Optional[float]
    discount_amount: Optional[float]
    total: Optional[float]
    confidence: Literal["confident", "partial", "ambiguous"]

class PaymentInfo(BaseModel):
    iban: Optional[str]
    payment_method: Optional[str]

class FieldWithIssue(BaseModel):
    field_name: str
    issue_type: Literal["not_found", "partial", "conflicting", "ambiguous"]
    notes: Optional[str]

class InvoiceExtraction(BaseModel):
    extraction_status: Literal["success", "partial", "failed"]
    failure_reason: Optional[str]
    invoice_number: Optional[str]
    vendor: Optional[Vendor]
    client: Optional[Client]
    dates: Optional[Dates]
    amounts: Optional[Amounts]
    line_items: Optional[List[LineItem]]
    payment_info: Optional[PaymentInfo]
    fields_with_issues: List[FieldWithIssue]

SYSTEM_PROMPT = """\
You are an expert invoice extraction system. Your sole task is to extract information EXPLICITLY visible in the document.

STRICT RULES:
1. LINE ITEM SEPARATION: Extract EVERY SINGLE ROW as a separate line item object. NEVER merge distinct products, services, or descriptions into a single line item. If a row only contains a description, extract it as a separate line item with null amounts.
2. NUMBER PARSING: Spanish/European formats use ',' for decimals and '.' for thousands.
   - A quantity of '1,000' is ALMOST ALWAYS 1 unit (1.0), NOT one thousand.
   - A price of '59,0000' is 59.0.
   - An amount of '1.260,00' is 1260.0.
   - Convert these strictly to standard JSON numbers (floats).
3. MATHEMATICAL VALIDATION: Verify that (Quantity * Unit Price) - Discount = Total for EACH line.
4. DISCOUNT HANDLING:
   - 'discount_percentage': Use this ONLY if there is a '%' sign or the column header explicitly indicates a percentage.
   - 'discount_amount': Use this ONLY if it is a direct monetary deduction or absolute value.

Respond ONLY with valid JSON conforming to the provided schema.
"""

USER_PROMPT_PREFIX = """\
Analyze the provided invoice images (which may span multiple pages) and extract all fields according to the strict JSON schema.
"""

def pad_to_multiple(img: Image.Image, multiple: int = PATCH_MULTIPLE) -> Image.Image:
    w, h = img.size
    new_w = ((w + multiple - 1) // multiple) * multiple
    new_h = ((h + multiple - 1) // multiple) * multiple
    if new_w == w and new_h == h:
        return img
    canvas = Image.new("RGB", (new_w, new_h), (255, 255, 255))
    canvas.paste(img, (0, 0))
    return canvas

def clamp_resolution(img: Image.Image, max_side: int = MAX_SIDE_PX) -> Image.Image:
    w, h = img.size
    if max(w, h) <= max_side:
        return img
    scale = max_side / max(w, h)
    return img.resize((int(w * scale), int(h * scale)), Image.Resampling.LANCZOS)

def img_to_b64(img: Image.Image) -> str:
    buf = BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")

def pdf_to_images(pdf_path: Path, dpi: int = TARGET_DPI) -> list[Image.Image]:
    zoom = dpi / PYMUPDF_BASE_DPI
    mat = fitz.Matrix(zoom, zoom)
    doc = fitz.open(pdf_path)
    pages = []
    for page in doc:
        pix = page.get_pixmap(matrix=mat, colorspace=fitz.csRGB, alpha=False)
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        img = clamp_resolution(img)
        pages.append(img)
    doc.close()
    return pages

def build_image_content(pages: list[Image.Image]) -> list[dict]:
    content = []
    min_pixels = 512 * 28 * 28
    max_pixels = 4608 * 28 * 28
    
    for img in pages:
        w, h = img.size
        new_h, new_w = smart_resize(h, w, min_pixels=min_pixels, max_pixels=max_pixels, factor=28)
        resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
        padded_img = pad_to_multiple(resized_img, multiple=28)
        
        content.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{img_to_b64(padded_img)}"},
            }
        )
    content.append({"type": "text", "text": USER_PROMPT_PREFIX})
    return content

def call_vllm(payload: dict) -> dict:
    time.sleep(2)
    response = requests.post(
        VLLM_URL,
        json=payload,
        headers={"Content-Type": "application/json"},
    )
    response.raise_for_status()
    return response.json()

def extract_invoice(pdf_path: str) -> dict:
    pages = pdf_to_images(Path(pdf_path))
    content = build_image_content(pages)

    payload = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": content},
        ],
        "temperature": 0.3,
        # "top_k": 20,
        # "min_p": 0.0,
        # "top_p": 0.8,
        # "presence_penalty": 1.5,
        # "repetition_penalty": 1.0,
        "chat_template_kwargs": {"enable_thinking": False},
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "invoice_extraction",
                "schema": InvoiceExtraction.model_json_schema(),
                "strict": True,
            },
        },
    }

    response_data = call_vllm(payload)
    message = response_data["choices"][0]["message"]

    content_str = message.get("content")
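    if content_str is None:
        # Fail with a descriptive message instead of the bare string "None".
        raise ValueError("Model response contained no content")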

    return json.loads(content_str)

if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Print a usage hint to stderr and exit with status 1.
        sys.exit(f"Usage: {sys.argv[0]} <invoice.pdf>")

    result = extract_invoice(sys.argv[1])
    print(json.dumps(result, indent=2, ensure_ascii=False))

This is my quant script, based on your branch/PR but using GPTQ.

import os
import shutil

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import snapshot_download
from transformers import AutoProcessor, AutoTokenizer, Qwen3_5MoeForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

MODEL_ID = "Qwen/Qwen3.5-35B-A3B"
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-GPTQ-W4A16-G32"

NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048

model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map=None,
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

samples_per_dataset = NUM_CALIBRATION_SAMPLES // 2

ds_ultrachat = load_dataset(
    "HuggingFaceH4/ultrachat_200k",
    split=f"train_sft[:{samples_per_dataset}]",
)
ds_nemotron = load_dataset(
    "nvidia/Nemotron-Post-Training-Dataset-v2",
    split=f"chat[:{samples_per_dataset}]",
)

ds_ultrachat = ds_ultrachat.select_columns(["messages"])
ds_nemotron = ds_nemotron.select_columns(["messages"])
ds = concatenate_datasets([ds_ultrachat, ds_nemotron])
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

recipe = GPTQModifier(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(
                num_bits=4,
                type=QuantizationType.INT,
                strategy=QuantizationStrategy.GROUP,
                group_size=32,
                symmetric=True,
                dynamic=False,
            ),
        )
    },
    ignore=[
        "lm_head",
        "re:.*mlp.gate$",
        "re:.*mlp.shared_expert_gate$",
        "re:.*linear_attn.*",
        "re:.*visual.*",
    ],
    bypass_divisibility_checks=False,
    block_size=128,
    dampening_frac=0.01,
    actorder="static",
    offload_hessians=False,
)

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    moe_calibrate_all_experts=True,
    pipeline="sequential",
)

model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)

cache_dir = snapshot_download(MODEL_ID, allow_patterns=["*.json"])

for filename in [
    "vocab.json",
    "preprocessor_config.json",
    "video_preprocessor_config.json",
    "tokenizer_config.json",
]:
    src = os.path.join(cache_dir, filename)
    dst = os.path.join(SAVE_DIR, filename)
    if os.path.exists(src):
        shutil.copyfile(src, dst)
        print(f"Copied: {filename}")
    else:
        print(f"Not Found in Cache: {filename}")   

The issue I'm seeing is that the model often fails to respond or hallucinates data when I request Structured Outputs (JSON). However, when I send the request without forcing a structured format, the responses make much more sense. Interestingly, models like Qwen3VL 30B A3B Instruct handle this correctly.

Tested using VLLM main branch.

Many thanks for your time.

I noticed your code uses from transformers import AutoProcessor, AutoTokenizer, Qwen3_5MoeForConditionalGeneration, which requires transformers>=5.2.0. However, the from llmcompressor import oneshot code indicates that the latest version of llmcompressor depends on transformers>=4.56.1, <=4.57.6.
Could you please advise on how to resolve this transformers version conflict? Which version of transformers are you currently using? Thank you.
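The conflict described above can be stated as a simple range check. The sketch below hard-codes the bounds quoted in that comment (transformers>=4.56.1,<=4.57.6 for llmcompressor, and 5.2.0 as the first version with Qwen3_5MoeForConditionalGeneration). `parse` and `in_range` are illustrative helpers, not part of any library, and the parser only handles plain X.Y.Z version strings.

```python
def parse(version: str) -> tuple[int, ...]:
    # Handles plain "X.Y.Z" strings only (no dev/rc suffixes).
    return tuple(int(part) for part in version.split("."))


def in_range(version: str, low: str, high: str) -> bool:
    # Inclusive on both ends, mirroring ">=low,<=high".
    return parse(low) <= parse(version) <= parse(high)


# Bounds as quoted in the comment above.
LLMCOMPRESSOR_MIN, LLMCOMPRESSOR_MAX = "4.56.1", "4.57.6"
QWEN3_5_MIN = "5.2.0"

ok = in_range(QWEN3_5_MIN, LLMCOMPRESSOR_MIN, LLMCOMPRESSOR_MAX)
print(ok)
```

With these bounds, no single released transformers version satisfies both requirements, which is why a pinned llmcompressor install cannot import the Qwen3.5 MoE model class.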

@dsikka
Collaborator

dsikka commented Mar 11, 2026

Hi @Sehyo I am going to break this PR and land it in smaller pieces as some of this functionality is now out of date.

Thank you for the contribution!

@dsikka dsikka self-assigned this Mar 11, 2026
@phaelon74
Contributor

phaelon74 commented Mar 11, 2026

> Hi @Sehyo I am going to break this PR and land it in smaller pieces as some of this functionality is now out of date.
>
> Thank you for the contribution!

Apologies for the ask, @dsikka, but can you please map it out? I am having to use my PR to make my Qwen3.5 quants work, so it would be nice to know which PRs you will split the implementation into, so I know when they land.

@Sehyo
Author

Sehyo commented Mar 12, 2026

> Hi @Sehyo I am going to break this PR and land it in smaller pieces as some of this functionality is now out of date.
>
> Thank you for the contribution!

Sorry, I have been busy the last few days.
Do you want to take this PR over, or should I assist?

@schoennenbeck

Possibly a stupid question: But how does this work without also relaxing the transformers upper bound? The Qwen3.5-MOE architecture has only been supported since transformers 5.2.0 and the current upper bound compatible with llmcompressor is 4.57.6 (or similar)

@phaelon74
Contributor

> Possibly a stupid question: But how does this work without also relaxing the transformers upper bound? The Qwen3.5-MOE architecture has only been supported since transformers 5.2.0 and the current upper bound compatible with llmcompressor is 4.57.6 (or similar)

At some point llm-compressor and vLLM will support Transformers 5.2/5.3.

For now, do it this way:
  • Install nightly llm-compressor
  • Install nightly compressed-tensors
  • Install nightly transformers

It will then work.

@schoennenbeck

> > Possibly a stupid question: But how does this work without also relaxing the transformers upper bound? The Qwen3.5-MOE architecture has only been supported since transformers 5.2.0 and the current upper bound compatible with llmcompressor is 4.57.6 (or similar)
>
> At somepoint LLM_Compressor/VLLM will support Transformers 5.2/5.3
>
> For now, Do it in this method: Install Nightly LLM-Compressor Install nightly compressed tensors Install nightly transformers.
>
> It will then work.

Thanks for the quick response. So this PR "only" adds the architecture support and the full functionality will still depend on other changes.

@phaelon74
Contributor

> > > Possibly a stupid question: But how does this work without also relaxing the transformers upper bound? The Qwen3.5-MOE architecture has only been supported since transformers 5.2.0 and the current upper bound compatible with llmcompressor is 4.57.6 (or similar)
> >
> > At somepoint LLM_Compressor/VLLM will support Transformers 5.2/5.3
> > For now, Do it in this method: Install Nightly LLM-Compressor Install nightly compressed tensors Install nightly transformers.
> > It will then work.
>
> Thanks for the quick response. So this PR "only" adds the architecture support and the full functionality will still depend on other changes.

This PR and/or mine add Qwen3.5 MoE modeling files, which allow all experts to be activated during calibration. To ensure intelligence persists, you must either use a MASSIVE sample size for calibration (think 16,000 samples) or use a modeling file to activate all experts.

Transformers 5.x is more than just Qwen3.5; it is an amalgamation of serious changes that vLLM and llm-compressor will need time to adapt to. That work is still ongoing in those teams. At some point in the future, both will natively support transformers 5.x.
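The idea of activating all experts during calibration can be illustrated with a toy MoE layer in plain Python. This is a sketch under stated assumptions, not the PR's implementation: `moe_forward`, the scalar "experts", and the fixed router are all hypothetical. It assumes that in calibration mode every expert processes every token (so calibration statistics can be collected for it) while only the routed experts contribute to the output.

```python
def moe_forward(tokens, router, experts, top_k, calibrate_all_experts=False):
    """Toy MoE layer: `router(token)` returns one score per expert."""
    seen = [0] * len(experts)  # how many tokens each expert processed
    outputs = []
    for token in tokens:
        scores = router(token)
        ranked = sorted(range(len(experts)), key=lambda i: scores[i], reverse=True)
        selected = set(ranked[:top_k])
        out = 0.0
        for idx, expert in enumerate(experts):
            if idx not in selected and not calibrate_all_experts:
                continue  # normal routing: unselected experts never see the token
            value = expert(token)  # a calibration observer would record this call
            seen[idx] += 1
            if idx in selected:
                out += scores[idx] * value  # only routed experts contribute
        outputs.append(out)
    return outputs, seen


# Four toy experts that scale their input, and a router that always
# prefers expert 0, then expert 1.
experts = [lambda x, k=k: (k + 1) * x for k in range(4)]
router = lambda x: [0.7, 0.2, 0.1, 0.0]

tokens = [1.0, 2.0, 3.0]
routed_out, routed_seen = moe_forward(tokens, router, experts, top_k=2)
calib_out, calib_seen = moe_forward(
    tokens, router, experts, top_k=2, calibrate_all_experts=True
)
print(routed_seen, calib_seen, routed_out == calib_out)
```

In this toy setup, normal routing leaves two experts with no tokens at all, while calibration mode exposes every expert to every token without changing the layer's output. With real routing the starved experts vary by input, which is why a very large sample count is the alternative mentioned above.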

@dsikka
Collaborator

dsikka commented Mar 17, 2026

FYI - closing in favour of: #2482

@HDCharles HDCharles closed this Mar 17, 2026
dsikka added a commit that referenced this pull request Mar 18, 2026
#2467)

## Summary
- Add Qwen3.5-27B example for NVFP4A16 quantization (`w4a16_fp4/nvfp4`)
- Add Qwen3.5-27B example for MXFP4A16 quantization (`w4a16_fp4/mxfp4`)

Ignore list includes:
- `lm_head` — output head
- `re:.*visual.*` — vision encoder (Qwen3.5 is a VLM)
- `re:.*linear_attn.*` — Gated DeltaNet fused projections incompatible
with microscale formats (ref #2383)
- `re:.*mtp.*` — multi-token prediction modules

> **Note:** Qwen3.5 (`qwen3_5` arch) requires `transformers>=5.x` which
is not yet compatible with llm-compressor. This PR is ready to land once
the transformers version bump is completed.

## Test plan
- [x] Verify quantization runs on Qwen3.5-27B with NVFP4A16 (blocked on
transformers compat)
- [x] Verify quantization runs on Qwen3.5-27B with MXFP4A16 (blocked on
transformers compat)
- [x] Confirm sample generation produces coherent output

---------

Signed-off-by: Ziming <frankziming26@outlook.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>

Labels

  • documentation: Improvements or additions to documentation
  • nvfp4: For any PR / issue related to NVFP4 support
  • quality-failed
  • qwen: For any PR / issue related to Qwen support
  • ready: When a PR is ready for review


9 participants