@@ -337,6 +337,17 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]

+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        # TODO(maxren, T210537195):
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force_fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -436,6 +447,16 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0

+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force_fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
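
For context, these overrides only change behavior when the partitioner is constructed with the force_fp32_dynamic_linear flag, in which case fp32 linear weights are reported as having no dependencies and are left out of the delegated partition. Below is a minimal usage sketch; it assumes XnnpackPartitioner forwards extra keyword arguments such as force_fp32_dynamic_linear to its per-op configs, and that the to_edge_transform_and_lower entry point and import paths match the current ExecuTorch release, so treat it as illustrative rather than authoritative.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class DynamicLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x):
        return self.linear(x)


model = DynamicLinear().eval()
example_inputs = (torch.randn(4, 32),)

# Export the eager model, then lower to XNNPACK. With force_fp32_dynamic_linear=True
# (assumed to be forwarded to the configs), the overridden _get_weight_deps above
# returns (True, []) for fp32 linears, so the weight node is not pulled into the
# delegated partition.
exported = torch.export.export(model, example_inputs)
edge = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner(force_fp32_dynamic_linear=True)],
)
executorch_program = edge.to_executorch()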