Skip to content

Commit eb74758

Browse files
Added support for subfunctions in VLMs (quic#699)
Signed-off-by: Abhishek Kumar Singh <sabhis@qti.qualcomm.com> Signed-off-by: abhishek-singh591 <sabhis@qti.qualcomm.com> Signed-off-by: Abhishek kumar singh <sabhis@qti.qualcomm.com>
1 parent 32f30c0 commit eb74758

38 files changed

+604
-74
lines changed

QEfficient/transformers/models/codegen/modeling_codegen.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
"""PyTorch Codegen model."""
99

10-
from typing import Optional, Tuple, Union
10+
from typing import Optional, Tuple, Type, Union
1111

1212
import torch
1313
from torch import nn
@@ -296,6 +296,15 @@ class QEffCodeGenForCausalLM(CodeGenForCausalLM):
296296
- update the hidden_states, and fix for onnx model
297297
"""
298298

299+
def get_submodules_for_export(self) -> Type[nn.Module]:
300+
"""
301+
Return the set of classes used as repeated layers across the model for subfunction extraction.
302+
Notes:
303+
This method should return the *class object* (not an instance).
304+
Downstream code can use this to find/build subfunctions for repeated blocks.
305+
"""
306+
return {QEffCodeGenBlock}
307+
299308
def forward(
300309
self,
301310
input_ids: Optional[torch.LongTensor] = None,

QEfficient/transformers/models/falcon/modeling_falcon.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
"""PyTorch Falcon model."""
99

1010
import math
11-
from typing import Optional, Tuple, Union
11+
from typing import Optional, Tuple, Type, Union
1212

1313
import torch
14+
import torch.nn as nn
1415
import torch.utils.checkpoint
1516
from torch.nn import functional as F
1617
from transformers.cache_utils import Cache
@@ -353,6 +354,15 @@ class QEffFalconForCausalLM(FalconForCausalLM):
353354
- update the hidden_states, and fix for onnx model
354355
"""
355356

357+
def get_submodules_for_export(self) -> Type[nn.Module]:
358+
"""
359+
Return the set of classes used as repeated layers across the model for subfunction extraction.
360+
Notes:
361+
This method should return the *class object* (not an instance).
362+
Downstream code can use this to find/build subfunctions for repeated blocks.
363+
"""
364+
return {QEffFalconDecoderLayer}
365+
356366
def forward(
357367
self,
358368
input_ids: torch.LongTensor = None,

QEfficient/transformers/models/gemma/modeling_gemma.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# -----------------------------------------------------------------------------
77

8-
from typing import List, Optional, Tuple, Union
8+
from typing import List, Optional, Tuple, Type, Union
99

1010
import torch
1111
from torch import nn
@@ -336,6 +336,15 @@ class QEffGemmaForCausalLM(GemmaForCausalLM):
336336
- add new args cache idx for the kv retention
337337
"""
338338

339+
def get_submodules_for_export(self) -> Type[nn.Module]:
340+
"""
341+
Return the set of classes used as repeated layers across the model for subfunction extraction.
342+
Notes:
343+
This method should return the *class object* (not an instance).
344+
Downstream code can use this to find/build subfunctions for repeated blocks.
345+
"""
346+
return {QEffGemmaDecoderLayer}
347+
339348
def forward(
340349
self,
341350
input_ids: torch.LongTensor = None,

QEfficient/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# -----------------------------------------------------------------------------
77

8-
from typing import Callable, List, Optional, Tuple, Union
8+
from typing import Callable, List, Optional, Tuple, Type, Union
99

1010
import torch
1111
from torch import nn
@@ -388,6 +388,15 @@ class QEffGemma2ForCausalLM(Gemma2ForCausalLM, GenerationMixin):
388388
- add new args cache idx for the kv retention
389389
"""
390390

391+
def get_submodules_for_export(self) -> Type[nn.Module]:
392+
"""
393+
Return the set of classes used as repeated layers across the model for subfunction extraction.
394+
Notes:
395+
This method should return the *class object* (not an instance).
396+
Downstream code can use this to find/build subfunctions for repeated blocks.
397+
"""
398+
return {QEffGemma2DecoderLayer}
399+
391400
def forward(
392401
self,
393402
input_ids: torch.LongTensor = None,

QEfficient/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# -----------------------------------------------------------------------------
77

88
import copy
9-
from typing import List, Optional, Tuple, Union
9+
from typing import List, Optional, Tuple, Type, Union
1010

1111
import torch
1212
from torch import nn
@@ -589,6 +589,15 @@ def __init__(self, model):
589589
self.model = model
590590
self.model.vision_model = self.model.vision_tower
591591

592+
def get_submodules_for_export(self) -> Type[nn.Module]:
593+
"""
594+
Return the set of classes used as repeated layers across the model for subfunction extraction.
595+
Notes:
596+
This method should return the *class object* (not an instance).
597+
Downstream code can use this to find/build subfunctions for repeated blocks.
598+
"""
599+
return {self.model.vision_tower.vision_model.encoder.layers[0].__class__}
600+
592601
def forward(self, pixel_values):
593602
image_features = self.model.get_image_features(pixel_values=pixel_values)
594603
return image_features
@@ -602,6 +611,15 @@ def __init__(self, model):
602611
self.config = self.model.config
603612
self.lm_head = self.model.lm_head
604613

614+
def get_submodules_for_export(self) -> Type[nn.Module]:
615+
"""
616+
Return the set of classes used as repeated layers across the model for subfunction extraction.
617+
Notes:
618+
This method should return the *class object* (not an instance).
619+
Downstream code can use this to find/build subfunctions for repeated blocks.
620+
"""
621+
return {QEffGemma3DecoderLayer}
622+
605623
def forward(
606624
self,
607625
input_ids,

QEfficient/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# -----------------------------------------------------------------------------
77

8-
from typing import Callable, Optional, Tuple, Union
8+
from typing import Callable, Optional, Tuple, Type, Union
99

1010
import torch
1111
from torch import nn
@@ -397,6 +397,15 @@ class QEffGPT2LMHeadModel(GPT2LMHeadModel):
397397
- add new args position idx for the cache_kwargs for kv retention
398398
"""
399399

400+
def get_submodules_for_export(self) -> Type[nn.Module]:
401+
"""
402+
Return the set of classes used as repeated layers across the model for subfunction extraction.
403+
Notes:
404+
This method should return the *class object* (not an instance).
405+
Downstream code can use this to find/build subfunctions for repeated blocks.
406+
"""
407+
return {QEffGPT2Block}
408+
400409
def forward(
401410
self,
402411
input_ids: Optional[torch.LongTensor] = None,

QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
"""PyTorch GPTBigCode model."""
99

10-
from typing import Optional, Tuple, Union
10+
from typing import Optional, Tuple, Type, Union
1111

1212
import torch
1313
import torch.utils.checkpoint
@@ -378,6 +378,15 @@ def forward(
378378

379379

380380
class QEffGPTBigCodeForCausalLM(GPTBigCodeForCausalLM):
381+
def get_submodules_for_export(self) -> Type[nn.Module]:
382+
"""
383+
Return the set of classes used as repeated layers across the model for subfunction extraction.
384+
Notes:
385+
This method should return the *class object* (not an instance).
386+
Downstream code can use this to find/build subfunctions for repeated blocks.
387+
"""
388+
return {QEffGPTBigCodeBlock}
389+
381390
def forward(
382391
self,
383392
input_ids: Optional[torch.Tensor] = None,

QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# -----------------------------------------------------------------------------
77
import math
88
import os
9-
from typing import Callable, Optional, Union
9+
from typing import Callable, Optional, Type, Union
1010

1111
import torch
1212
from torch import nn
@@ -1205,6 +1205,16 @@ def forward(
12051205

12061206

12071207
class QEffGptOssForCausalLM(GptOssForCausalLM):
1208+
def get_submodules_for_export(self) -> Type[nn.Module]:
1209+
"""
1210+
Return the set of classes used as repeated layers across the model for subfunction extraction.
1211+
1212+
Notes:
1213+
This method should return the *class object* (not an instance).
1214+
Downstream code can use this to find/build subfunctions for repeated blocks.
1215+
"""
1216+
return {QEffGptOssDecoderLayer}
1217+
12081218
def forward(
12091219
self,
12101220
input_ids: Optional[torch.LongTensor] = None,

QEfficient/transformers/models/gptj/modeling_gptj.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
"""PyTorch GPT-J model."""
99

10-
from typing import Optional, Tuple, Union
10+
from typing import Optional, Tuple, Type, Union
1111

1212
import torch
1313
from torch import nn
@@ -318,6 +318,15 @@ class QEffGPTJForCausalLM(GPTJForCausalLM):
318318
- update the hidden_states, and fix for onnx model
319319
"""
320320

321+
def get_submodules_for_export(self) -> Type[nn.Module]:
322+
"""
323+
Return the set of classes used as repeated layers across the model for subfunction extraction.
324+
Notes:
325+
This method should return the *class object* (not an instance).
326+
Downstream code can use this to find/build subfunctions for repeated blocks.
327+
"""
328+
return {QEffGPTJBlock}
329+
321330
def forward(
322331
self,
323332
input_ids: Optional[torch.LongTensor] = None,

QEfficient/transformers/models/granite/modeling_granite.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# -----------------------------------------------------------------------------
77

8-
from typing import Callable, List, Optional, Tuple, Union
8+
from typing import Callable, List, Optional, Tuple, Type, Union
99

1010
import torch
1111
from torch import nn
@@ -347,6 +347,15 @@ class QEffGraniteForCausalLM(GraniteForCausalLM):
347347
Copied from GraniteForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granite/modeling_granite.py
348348
"""
349349

350+
def get_submodules_for_export(self) -> Type[nn.Module]:
351+
"""
352+
Return the set of classes used as repeated layers across the model for subfunction extraction.
353+
Notes:
354+
This method should return the *class object* (not an instance).
355+
Downstream code can use this to find/build subfunctions for repeated blocks.
356+
"""
357+
return {QEffGraniteDecoderLayer}
358+
350359
def forward(
351360
self,
352361
input_ids: torch.LongTensor = None,

0 commit comments

Comments
 (0)