from dataclasses import dataclass
import itertools
+import math
import sys
import torch
import triton
# utilities
from triton_kernels import target_info
from triton_kernels.numerics import InFlexData, OutFlexData
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+from triton_kernels.routing import ExptData, GatherIndx, RoutingData, ScatterIndx
+from triton.tools.tensor_descriptor import TensorDescriptor
# details
from .matmul_ogs_details._matmul_ogs import _compute_writeback_idx
from .matmul_ogs_details._matmul_ogs import _matmul_ogs
from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
from .matmul_ogs_details._finalize_matmul import _finalize_matmul
-from .matmul_ogs_details.opt_flags import make_opt_flags
+from .matmul_ogs_details.opt_flags import make_opt_flags, OptFlags
from .matmul_ogs_details.fast_contiguous import fast_contiguous
from .numerics_details.mxfp import SwizzlingType
from .specialize import specialize
+from typing import Tuple, Optional


@dataclass
@@ -95,6 +98,84 @@ def should_upcast_indices(*args):
    return any(tensor is not None and can_overflow_int32(tensor) for tensor in args)


+class TensorDescriptorBuilder:
+    """Builder for creating different types of tensor descriptors"""
+
+    @staticmethod
+    def create_basic_descriptor(tensor: torch.Tensor, block_shape: Tuple[int, ...],
+                                transpose: bool = False) -> TensorDescriptor:
+        """Create a basic tensor descriptor with optional transpose"""
+        if transpose:
+            block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
+            tensor = tensor.permute(0, 2, 1)
+        return TensorDescriptor.from_tensor(tensor, block_shape=block_shape)
+
+    @staticmethod
+    def create_weight_descriptor(w_tensor: torch.Tensor, block_k: int, block_n: int,
+                                 transpose: bool) -> TensorDescriptor:
+        """Create a tensor descriptor for weight matrix"""
+        # Two e2m1 packed in a uint8 or a single fp8
+        W_PACK_DIVISOR = 2 if w_tensor.dtype == torch.uint8 else 1
+        PACKED_BLOCK_K_W = block_k // W_PACK_DIVISOR
+        return TensorDescriptorBuilder.create_basic_descriptor(w_tensor, block_shape=[1, PACKED_BLOCK_K_W, block_n],
+                                                               transpose=transpose)
+
+    @staticmethod
+    def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, K: int, N: int,
+                                      mx_scale_stride_k: int, mx_scale_stride_n: int, n_expts_tot: int, batch_size: int,
+                                      expt_data: Optional[ExptData], swizzle_mx: bool,
+                                      transpose: bool) -> TensorDescriptor:
+        """Create a tensor descriptor for block scale factors"""
+        MX_PACK_DIVISOR = 32
+        MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR
+        PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR
+
+        if swizzle_mx:
+            num_expt_x_ncol = (n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else
+                               batch_size) * ((N + 127) // 128)
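+            # A reading of the shapes below: scales for each (expert-or-batch,
+            # 128-column) group are laid out as ceil(PackedK / 4) chunks of
+            # 2 x 256 values, which the 5-D descriptor addresses directly.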
+            return TensorDescriptor(
+                base=mx_tensor, shape=[1, num_expt_x_ncol, (PackedK + 3) // 4, 2, 256],
+                strides=[num_expt_x_ncol * mx_scale_stride_n, mx_scale_stride_n, mx_scale_stride_k, 256,
+                         1], block_shape=[1, block_n // 128, MX_SCALE_BLOCK_K // 4, 2, 256])
+        else:
+            # Non-optimal SF layout, expect slow transfers
+            # from global to shmem and from shmem to tmem
+            return TensorDescriptorBuilder.create_basic_descriptor(mx_tensor,
+                                                                   block_shape=[1, MX_SCALE_BLOCK_K,
+                                                                                block_n], transpose=transpose)
+
+    @staticmethod
+    def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int,
+                                       block_k: int) -> TensorDescriptor:
+        """Create a tensor descriptor for input matrix X via TMA gather"""
+        x_desc = x_tensor.squeeze()
+        assert x_desc.ndim == 2, "TMA gather descriptor requires 2D input"
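+        # The rows that are actually read are selected at runtime by the gather
+        # indices, so the row extent is left effectively unbounded (int32 max).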
+        INT_MAX = 2147483647
+        return TensorDescriptor(base=x_desc, shape=[INT_MAX, K], strides=[x_stride_1, x_stride_2],
+                                block_shape=[1, block_k])
+
+    @staticmethod
+    def create_input_descriptor_load(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_m: int,
+                                     block_k: int) -> TensorDescriptor:
+        """Create a tensor descriptor for input matrix X via TMA"""
+        x_desc = x_tensor.squeeze()
+        assert x_desc.ndim in [2, 3], "LHS input TMA descriptor builder expects 2D or 3D input"
+        return TensorDescriptor(base=x_desc, shape=[x_desc.shape[0], K], strides=[x_stride_1, x_stride_2],
+                                block_shape=[block_m, block_k])
+
+    @staticmethod
+    def create_input_descriptor(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_k: int,
+                                block_m: int, use_gather_tma: bool, use_load_tma: bool) -> TensorDescriptor:
+        """Create a tensor descriptor for input matrix X based on TMA usage"""
+        if use_gather_tma:
+            return TensorDescriptorBuilder.create_input_descriptor_gather(x_tensor, K, x_stride_1, x_stride_2, block_k)
+        elif use_load_tma:
+            return TensorDescriptorBuilder.create_input_descriptor_load(x_tensor, K, x_stride_1, x_stride_2, block_m,
+                                                                        block_k)
+        else:
+            return x_tensor
+
+
# ---------------------
# Numerics
# ---------------------
@@ -490,7 +571,6 @@ def init_allocation(x, w, precision_config, fused_activation, routing_data, gath
        scratchpad["matmul"] = ((opt_flags.split_k, x.shape[0], M, N), dtype)
    return MatmulAllocation(x.device, output, scratchpad)

-
def apply_allocation(allocation: MatmulAllocation, output):
    ret = dict()
    if output is None:
@@ -504,10 +584,82 @@ def apply_allocation(allocation: MatmulAllocation, output):
    }
    return ret

+
# -----------------------------------------------------------------------------
# Triton Implementation
# -----------------------------------------------------------------------------

+def _create_tma_descriptors(
+    x: torch.Tensor,
+    x_tensor: torch.Tensor,
+    w_tensor: torch.Tensor,
+    mx_tensor: Optional[torch.Tensor],
+    routing_data: RoutingData,
+    mx_ctx: MicroscalingCtx,
+    expt_data: ExptData,
+    opt_flags: OptFlags,
+    batch_size: int,
+    K: int,
+    N: int,
+    mx_scale_stride_k: int,
+    mx_scale_stride_n: int,
+    USE_GATHER_TMA: bool,
+    X_USE_LOAD_TMA: bool,
+    w_transpose: bool,
+    mx_transpose: bool,
+) -> Tuple[bool, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    """Create and cache TMA descriptors for tensors."""
+    use_host_tma_descriptors = opt_flags.is_persistent and target_info.cuda_capability_geq(10, 0)
+
+    x_desc, w_desc = [None] * 2
+    descriptors = []
+    # The dense case currently uses on-device descriptor updates
+    # so we bail out on using host descriptors in that case
+    if use_host_tma_descriptors:
+        if USE_GATHER_TMA or X_USE_LOAD_TMA:
+            x_desc = TensorDescriptorBuilder.create_input_descriptor(
+                x_tensor, K, x.stride(1), x.stride(2),
+                opt_flags.block_k, opt_flags.block_m,
+                USE_GATHER_TMA, X_USE_LOAD_TMA
+            )
+            descriptors.append(x_desc)
+        if expt_data is not None and len(expt_data.block_pid_map) > 0:
+            w_desc = TensorDescriptorBuilder.create_weight_descriptor(
+                w_tensor, opt_flags.block_k, opt_flags.block_n, w_transpose
+            )
+            is_microscaled_format = (mx_ctx.weight_scale is not None) and (w_tensor.dtype == torch.uint8)
+            if is_microscaled_format:
+                # Pad the inner shape to 128 for mxfp4 weights
+                # for mixed-precision fp8 x mxfp4 compute
+                pad = 128
+                dim_to_pad = -1 if w_transpose else -2
+                old_size = w_desc.shape[dim_to_pad]
+                padded_size = math.ceil(old_size / pad) * pad
+                if padded_size != old_size:
+                    w_desc.shape = list(w_desc.shape)
+                    w_desc.shape[dim_to_pad] = padded_size
+            descriptors.append(w_desc)
+        # Optional MX scale descriptor
+        descriptors.append(None)
+        if mx_tensor is not None:
+            descriptors[-1] = TensorDescriptorBuilder.create_block_scale_descriptor(
+                mx_tensor, opt_flags.block_k, opt_flags.block_n, K, N,
+                mx_scale_stride_k, mx_scale_stride_n, routing_data.n_expts_tot,
+                batch_size,
+                expt_data, mx_ctx.swizzle_scale, mx_transpose
+            )
+
+    # TODO: Currently all or none, instead should support a mixture
+    # of host and device descriptors
+    if None in descriptors or len(descriptors) == 0:
+        descriptors = [x_tensor, w_tensor, mx_tensor]
+        use_host_tma_descriptors = False
+    if opt_flags.is_persistent:
+        opt_flags.target_kernel_kwargs["USE_HOST_TMA_DESCRIPTORS"] = use_host_tma_descriptors
+
+    return use_host_tma_descriptors, *descriptors
+
+
def matmul_ogs(x, w, bias,
               routing_data: RoutingData | None = None,
               gather_indx: GatherIndx | None = None,
@@ -601,22 +753,47 @@ def matmul_ogs(x, w, bias,
    flex = precision_config.flex_ctx
    bias_stride = None if bias is None else bias.stride(0)
    num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
+
    kernels = get_kernels(epilogue.specs, fused_activation.specs)
    expt_data = routing_data.expt_data
    block_m = opt_flags.block_m
    expt_hist = None if expt_data is None else expt_data.hist
    expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
    expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
    expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
+
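+    # Select the TMA strategy for the LHS: gather TMA requires Blackwell support
+    # and a gather index; plain TMA loads are used only when no gather is needed.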
+    HAS_TMA_GS = target_info.cuda_capability_geq(10, 0)
+    USE_GATHER_TMA = HAS_TMA_GS and gather_indx is not None
+    X_USE_LOAD_TMA = gather_indx is None and not USE_GATHER_TMA
+    _, x_tensor, w_tensor, mx_tensor = _create_tma_descriptors(
+        x=x,
+        x_tensor=flex.lhs_data.reinterpret(x),
+        w_tensor=flex.rhs_data.reinterpret(w),
+        mx_tensor=mx_ctx.weight_scale,
+        routing_data=routing_data,
+        mx_ctx=mx_ctx,
+        expt_data=expt_data,
+        opt_flags=opt_flags,
+        batch_size=batch_size,
+        K=K,
+        N=N,
+        mx_scale_stride_k=mx_scale_stride_k,
+        mx_scale_stride_n=mx_scale_stride_n,
+        USE_GATHER_TMA=USE_GATHER_TMA,
+        X_USE_LOAD_TMA=X_USE_LOAD_TMA,
+        w_transpose=w.stride(2) != 1,
+        mx_transpose=mx_scale_stride_n != 1,
+    )
+
    (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
        flex.out_data.reinterpret(memory["output"]),
        flex.out_data.reinterpret(out0), *out0.stride(),
        *out0_flex,
-        flex.lhs_data.reinterpret(x), x.stride(0), x.stride(1), x.stride(2),
+        x_tensor, x.stride(0), x.stride(1), x.stride(2),
        flex.lhs_data.scale,
-        flex.rhs_data.reinterpret(w), w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1,
+        w_tensor, w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1,
        flex.rhs_data.scale,
-        mx_ctx.weight_scale, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1,
+        mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1,
        bias, bias_stride,
        x.shape[1],
        x.shape[1] if routing_data.expt_hist is None else None,