Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion QEfficient/generation/embedding_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,14 +252,21 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -

# Process image and text
inputs = self._processor(images=image, text=prompt, return_tensors="pt")

if (
hasattr(self._qeff_model.model.config, "model_type")
and self._qeff_model.model.config.model_type == "qwen2_5_vl"
):
inputs = self._qeff_model.model.prepare_inputs_for_generation(
inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
)

if (
hasattr(self._qeff_model.model.config, "model_type")
and self._qeff_model.model.config.model_type == "qwen3_vl_moe"
):
inputs = self._qeff_model.model.prepare_inputs_for_generation(
inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
)

# Convert to float32 if needed
if "pixel_values" in inputs:
Expand Down
36 changes: 32 additions & 4 deletions QEfficient/generation/vlm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ def __init__(
self.is_qwen2_5_vl = (
hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl"
)
self.is_qwen3_vl_moe=(
hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl_moe"
)
self.qeff_model = qeff_model
self.processor = processor
self.tokenizer = tokenizer
Expand Down Expand Up @@ -256,9 +259,10 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):
outputs, position_ids, generation_len = self.run_prefill(
next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1)
)

if self.is_qwen2_5_vl:
_ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id)
elif self.is_qwen3_vl_moe:
_ = self.update_decode_inputs_qwen3_vl_moe(outputs,position_ids,generation_len,decode_batch_id)
else:
_ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

Expand All @@ -283,6 +287,27 @@ def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len,
self.generation_len[decode_batch_id or slice(None)] = generation_len
return next_token_id

def update_decode_inputs_qwen3_vl_moe(self, outputs, position_ids, generation_len, decode_batch_id=None):
    """
    Updates the decode inputs with the values generated by prefill (qwen3_vl_moe path).

    Args:
        outputs (dict): The outputs of the model's prefill run.
        position_ids (array): The (M)RoPE position IDs; axis 0 holds the rope sections,
            so the batch axis of ``decode_pos_ids`` is axis 1.
        generation_len (int): The generation length.
        decode_batch_id (array or int, optional): The decode batch slot to update.
            If None, all slots are updated. Defaults to None.

    Returns:
        next_token_id (array): The next token ID.
    """
    next_token_id = self._fetch_next_token_id(outputs)

    # NOTE: `decode_batch_id or slice(None)` is wrong for falsy ids — batch id 0
    # (or np.array([[0]])) would silently update EVERY slot. Test identity instead.
    batch_idx = slice(None) if decode_batch_id is None else decode_batch_id

    # Store the generated values.
    self.decode_input_ids[batch_idx] = next_token_id
    # Batch axis of decode_pos_ids is axis 1 (axis 0 = rope sections); using the
    # explicit index also avoids `[:, None]` inserting a spurious axis when
    # decode_batch_id is None.
    self.decode_pos_ids[:, batch_idx] = position_ids.squeeze(1)
    self.generated_ids[batch_idx, 0] = next_token_id.squeeze(1)
    self.generation_len[batch_idx] = generation_len
    return next_token_id

def _execute_chunked_prefill(
self,
lang_inputs: Dict[str, np.ndarray],
Expand Down Expand Up @@ -583,12 +608,12 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length)
if self.is_qwen2_5_vl:
self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64)

if self.is_qwen3_vl_moe:
self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64)
# Create prompt queue
prompt_queue = deque(vision_prompts)

start = perf_counter()

# Pre-process ALL vision inputs and cache them
logger.info("Pre-processing all vision inputs...")
for batch_id in range(min(self.full_batch_size, len(vision_prompts))):
Expand All @@ -610,7 +635,6 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,

# Reset prompt queue for prefill
prompt_queue = deque(vision_prompts)

self.batch_index = None

# Run prefill for all inputs using cached vision
Expand Down Expand Up @@ -696,6 +720,10 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation
self.update_decode_inputs_qwen2_5_vl(
outputs, position_ids_decode, generation_len_final, decode_batch_id
)
if self.is_qwen3_vl_moe:
self.update_decode_inputs_qwen3_vl_moe(
outputs, position_ids_decode, generation_len_final, decode_batch_id
)
else:
self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id)
else:
Expand Down
2 changes: 0 additions & 2 deletions QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,9 @@ def __init__(
**kwargs,
):
# Remove a pre-supplied "layers" kwarg if present to avoid passing a duplicate argument
# breakpoint()
kwargs.pop("layers", None)
from transformers.cache_utils import Cache # Import here to avoid circular import

# breakpoint()
layers = []
# If a config is passed, use it to infer the layer types and initialize accordingly
if len(layers) == 0:
Expand Down
5 changes: 3 additions & 2 deletions QEfficient/transformers/models/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
Qwen3VLMoeTextDecoderLayer,
Qwen3VLMoeTextModel,
Qwen3VLMoeTextRMSNorm,
Qwen3VLMoeTextSparseMoeBlock,
Qwen3VLMoeVisionAttention,
Qwen3VLMoeVisionModel,
)
Expand Down Expand Up @@ -398,7 +399,7 @@
QEffQwen3VLMoeTextAttention,
QEffQwen3VLMoeTextDecoderLayer,
QEffQwen3VLMoeTextModel,
# QEffQwen3VLMoeTextSparseMoeBlock,
QEffQwen3VLMoeTextSparseMoeBlock,
QEffQwen3VLMoeVisionAttention,
QEffQwen3VLMoeVisionModel,
)
Expand Down Expand Up @@ -603,7 +604,7 @@ class KVCacheTransform(ModuleMappingTransform):
Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention,
Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel,
Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel,
# Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock,
Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock,
# Grok1
# Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel,
# Starcoder2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def forward(
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
# kv_seq_len = key_states.shape[-2]
kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -711,23 +711,29 @@ class QEffQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock):
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sparse-MoE block: route each token to its top-k experts and combine their outputs.

    Args:
        hidden_states: Tensor of shape (B, S, H).

    Returns:
        Tuple of (experts_out, router_logits) where experts_out has shape (B, S, H)
        and router_logits has shape (B*S, num_experts).

    Note: this span contained both the superseded and the replacement implementation
    interleaved (router computed twice, interleaved `::2` gate/up split, early return
    leaving dead code). Only the coherent replacement path is kept.
    """
    B, S, H = hidden_states.shape
    T = B * S
    x = hidden_states.view(T, H)

    # Router: softmax over experts, keep top-k, renormalise kept weights to sum to 1.
    router_logits = self.gate(x)
    prob = F.softmax(router_logits, dim=-1, dtype=torch.float)
    top_w, top_i = torch.topk(prob, self.top_k, dim=-1)
    top_w = top_w / top_w.sum(dim=1, keepdim=True)
    top_w = top_w.to(x.dtype)

    # Gather the chosen experts' weights for every (token, expert) pair.
    idx = top_i.reshape(-1)
    w_up = self.experts.gate_up_proj.index_select(0, idx)
    w_dn = self.experts.down_proj.index_select(0, idx)

    # Replicate each token top_k times so each (token, expert) pair is one bmm row.
    xk = x.unsqueeze(1).expand(-1, self.top_k, -1).contiguous().view(-1, 1, H)
    gate_up = torch.bmm(xk, w_up)
    half = gate_up.size(-1) // 2  # fused projection: first half = gate, second half = up
    gate, up = gate_up[..., :half], gate_up[..., half:]
    intermediate = up * self.experts.act_fn(gate)
    experts_out = torch.bmm(intermediate, w_dn)

    # Weight each expert's output by its routing probability and sum over the k experts.
    experts_out = experts_out.view(T, self.top_k, H) * top_w.unsqueeze(-1)
    experts_out = experts_out.sum(dim=1).view(B, S, H)

    return experts_out, router_logits


class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
Expand All @@ -737,44 +743,6 @@ def get_qeff_vision_encoder(self):
def get_qeff_language_decoder(self):
    """Return the language-decoder wrapper for this model.

    NOTE(review): ``QEffQwen3VLDecoderWrapper`` is defined elsewhere in this
    module — presumably it adapts the text decoder for QEff export; confirm.
    """
    return QEffQwen3VLDecoderWrapper(self)

# def forward(
# self,
# input_ids,
# position_ids,
# past_key_values,
# pixel_values:Optional[torch.FloatTensor] = None,
# image_idx:Optional[torch.LongTensor] = None,
# comp_ctx_lengths: Optional[List[int]] = None,
# batch_index: Optional[torch.LongTensor] = None,
# image_grid_thw=None,
# ):
# image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)[0]
# bs = image_grid_thw.shape[0]
# split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs)

# inputs_embeds = self.model.get_input_embeddings()(input_ids)
# B, N, C = inputs_embeds.shape
# selected = input_ids == self.model.config.image_token_id
# indices1 = selected.to(torch.int64).cumsum(1) - 1
# indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
# indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
# image_features_expanded = image_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
# image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
# inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
# outputs = self.language_model(
# inputs_embeds=inputs_embeds,
# position_ids=position_ids,
# past_key_values=past_key_values,
# comp_ctx_lengths=comp_ctx_lengths,
# batch_index=batch_index,
# use_cache=True,
# )
# logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True)
# hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index]
# logits = self.lm_head(hidden_states)
# image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
# return logits, image_embeds, image_idx, outputs.past_key_values

def get_dummy_inputs(
self,
comp_ctx_lengths: Optional[List[int]] = None,
Expand Down Expand Up @@ -1036,7 +1004,7 @@ def get_onnx_dynamic_axes(
lang_dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"position_ids": {1: "batch_size", 2: "seq_len"},
"vision_embeds": {0: "batch_size", 1: "vision_size"},
"vision_embeds": {0: "vision_batch_size", 1: "vision_size"},
}

for i in range(num_layers):
Expand Down Expand Up @@ -1102,6 +1070,7 @@ def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=
inputs["position_ids"] = F.pad(
inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
)
inputs.pop("image_grid_thw", None)
return inputs

def get_inputs_info(self):
Expand Down
11 changes: 8 additions & 3 deletions examples/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
model_id, attn_implementation="eager", kv_offload=True, config=config
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

Check failure on line 21 in examples/qwen3_vl.py

View workflow job for this annotation

GitHub Actions / lint

ruff (F821)

examples/qwen3_vl.py:21:13: F821 Undefined name `transformers`
processor = AutoProcessor.from_pretrained(model_id)
### set skip_vision=True to run text only; leave it False to include the image ###
skip_vision = False
Expand Down Expand Up @@ -60,7 +61,7 @@
return_tensors="pt",
)
inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size)
streamer = TextStreamer(processor.tokenizer)
streamer = TextStreamer(tokenizer)
output = qeff_model.generate(inputs=inputs, generation_len=100)
print(output.generated_ids)
print(processor.tokenizer.batch_decode(output.generated_ids))
Expand All @@ -77,6 +78,8 @@
num_devices=4,
height=354,
width=536,
# height=1024,
# width=1024,
mxfp6_matmul=True,
mxint8_kv_cache=True,
aic_enable_depth_first=True,
Expand All @@ -85,7 +88,9 @@

### IMAGE + TEXT ###
image_url = "https://picsum.photos/id/237/536/354"
# image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
# image_url = (
# "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
# )

image = Image.open(requests.get(image_url, stream=True).raw)

Expand Down Expand Up @@ -122,7 +127,7 @@
return_tensors="pt",
)
inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size)
streamer = TextStreamer(processor.tokenizer)
streamer = TextStreamer(tokenizer)
output = qeff_model.generate(inputs=inputs, generation_len=100)
print(output.generated_ids)
print(processor.tokenizer.batch_decode(output.generated_ids))
Expand Down
Loading
Loading