@@ -144,11 +144,17 @@ def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n
144144 block_shape = [1 , MX_SCALE_BLOCK_K ,
145145 block_n ], transpose = transpose )
146146
147+ @staticmethod
148+ def squeeze_after_dim (x , dim = 2 ):
149+ shape = list (x .shape )
150+ new_shape = [s for s in shape [:dim - 1 ] if s != 1 ] + shape [dim - 1 :]
151+ return x .view (* new_shape )
152+
147153 @staticmethod
148154 def create_input_descriptor_gather (x_tensor : torch .Tensor , K : int , x_stride_1 : int , x_stride_2 : int ,
149155 block_k : int ) -> TensorDescriptor :
150156 """Create a tensor descriptor for input matrix X via TMA gather"""
151- x_desc = x_tensor . squeeze ( )
157+ x_desc = TensorDescriptorBuilder . squeeze_after_dim ( x_tensor )
152158 assert x_desc .ndim == 2 , "TMA gather descriptor requires 2D input"
153159 INT_MAX = 2147483647
154160 return TensorDescriptor (base = x_desc , shape = [INT_MAX , K ], strides = [x_stride_1 , x_stride_2 ],
@@ -158,7 +164,7 @@ def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: i
158164 def create_input_descriptor_load (x_tensor : torch .Tensor , K : int , x_stride_1 : int , x_stride_2 : int , block_m : int ,
159165 block_k : int ) -> TensorDescriptor :
160166 """Create a tensor descriptor for input matrix X via TMA"""
161- x_desc = x_tensor . squeeze ( )
167+ x_desc = TensorDescriptorBuilder . squeeze_after_dim ( x_tensor )
162168 assert x_desc .ndim in [2 , 3 ], "LHS input TMA descriptor builder expects 2D or 3D input"
163169 return TensorDescriptor (base = x_desc , shape = [x_desc .shape [0 ], K ], strides = [x_stride_1 , x_stride_2 ],
164170 block_shape = [block_m , block_k ])
0 commit comments