Skip to content

Commit f5f6df8

Browse files
greg-kwasniewski1videodanchik
authored andcommitted
[TRTLLM-10358][feat] Added proper rescaling of FP4 weights (NVIDIA#10378)
Signed-off-by: greg-kwasniewski1 <[email protected]> Signed-off-by: Daniil Kulko <[email protected]>
1 parent 69dc559 commit f5f6df8

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@ def is_any_lin_op(node: Node) -> bool:
304304
return is_linear_op(node) or is_fake_quantized_linear_op(node)
305305

306306

307+
def is_fp4_op(node: Node) -> bool:
308+
return is_op(
309+
node,
310+
[
311+
torch.ops.auto_deploy.torch_quant_nvfp4_linear,
312+
torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear,
313+
],
314+
)
315+
316+
307317
def is_any_moe_op(node: Node) -> bool:
308318
return is_op(
309319
node,
@@ -733,16 +743,20 @@ def boundary_condition(n):
733743
return subgraph_nodes
734744

735745

736-
def get_weight_shape(
737-
node: Node, dim: Optional[int] = None
738-
) -> Optional[Union[int, Tuple[int, ...]]]:
746+
def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]:
739747
"""Get the shape of the weight node."""
740748
if not is_any_lin_op(node):
741749
return None
750+
s = list(shape(extract_weight_node(node)))
751+
if len(s) == 0:
752+
return None
753+
if is_fp4_op(node):
754+
# FP4 weights are packed as uint8 type with 2 FP4 values per element
755+
s[-1] *= 2
742756
if dim is None:
743-
return shape(extract_weight_node(node))
757+
return s
744758
else:
745-
return shape(extract_weight_node(node))[dim]
759+
return s[dim]
746760

747761

748762
def get_layer_after_linear_node(

0 commit comments

Comments
 (0)