66# -----------------------------------------------------------------------------
77
88import copy
9- from typing import List , Optional , Tuple , Union
9+ from typing import List , Optional , Tuple , Type , Union
1010
1111import torch
1212from torch import nn
@@ -589,6 +589,15 @@ def __init__(self, model):
589589 self .model = model
590590 self .model .vision_model = self .model .vision_tower
591591
592+ def get_submodules_for_export (self ) -> Type [nn .Module ]:
593+ """
594+ Return the set of class used as the repeated layer across the model for subfunction extraction.
595+ Notes:
596+ This method should return the *class object* (not an instance).
597+ Downstream code can use this to find/build subfunctions for repeated blocks.
598+ """
599+ return {self .model .vision_tower .vision_model .encoder .layers [0 ].__class__ }
600+
592601 def forward (self , pixel_values ):
593602 image_features = self .model .get_image_features (pixel_values = pixel_values )
594603 return image_features
@@ -602,6 +611,15 @@ def __init__(self, model):
602611 self .config = self .model .config
603612 self .lm_head = self .model .lm_head
604613
614+ def get_submodules_for_export (self ) -> Type [nn .Module ]:
615+ """
616+ Return the set of class used as the repeated layer across the model for subfunction extraction.
617+ Notes:
618+ This method should return the *class object* (not an instance).
619+ Downstream code can use this to find/build subfunctions for repeated blocks.
620+ """
621+ return {QEffGemma3DecoderLayer }
622+
605623 def forward (
606624 self ,
607625 input_ids ,
0 commit comments