@@ -38,23 +38,11 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorIn
3838 line_sizes : & MatmulLineSizes ,
3939 ) -> Self :: RuntimeArg < ' a , R > {
4040 TensorInputsLaunch :: new (
41- lhs. data ( )
42- . try_as_tensor_arg ( line_sizes. lhs )
43- . expect ( "valid vec lhs" ) ,
44- lhs. scale ( )
45- . map ( |it| it. try_as_tensor_arg ( 1 ) . expect ( "vec=1" ) )
46- . into ( ) ,
47- rhs. data ( )
48- . try_as_tensor_arg ( line_sizes. rhs )
49- . expect ( "valid vec rhs" ) ,
50- rhs. scale ( )
51- . map ( |it| it. try_as_tensor_arg ( 1 ) . expect ( "vec=1" ) )
52- . into ( ) ,
53- bias. map ( |it| {
54- it. try_as_tensor_arg ( line_sizes. out )
55- . expect ( "valid vec out" )
56- } )
57- . into ( ) ,
41+ lhs. data ( ) . as_tensor_arg ( line_sizes. lhs ) ,
42+ lhs. scale ( ) . map ( |it| it. as_tensor_arg ( 1 ) ) . into ( ) ,
43+ rhs. data ( ) . as_tensor_arg ( line_sizes. rhs ) ,
44+ rhs. scale ( ) . map ( |it| it. as_tensor_arg ( 1 ) ) . into ( ) ,
45+ bias. map ( |it| it. as_tensor_arg ( line_sizes. out ) ) . into ( ) ,
5846 )
5947 }
6048}
@@ -116,9 +104,7 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
116104 channels_per_pixel : tile_size_k,
117105 pixels_per_column : stage_m,
118106 } ,
119- lhs. data ( )
120- . try_as_tensor_arg ( line_sizes. lhs )
121- . expect ( "valid vec lhs" ) ,
107+ lhs. data ( ) . as_tensor_arg ( line_sizes. lhs ) ,
122108 lhs_elem,
123109 )
124110 . with_elem_stride ( elem_stride)
@@ -128,15 +114,12 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
128114 TensorMapFormat :: Tiled {
129115 tile_size : stage_size_rhs,
130116 } ,
131- rhs. data ( ) . try_as_tensor_arg ( 1 ) . expect ( "vec=1" ) ,
117+ rhs. data ( ) . as_tensor_arg ( 1 ) ,
132118 Rhs :: as_type_native_unchecked ( ) ,
133119 )
134120 . with_prefetch ( prefetch_rhs) ;
135121
136- let bias = bias. map ( |it| {
137- it. try_as_tensor_arg ( line_sizes. out )
138- . expect ( "valid vec out" )
139- } ) ;
122+ let bias = bias. map ( |it| it. as_tensor_arg ( line_sizes. out ) ) ;
140123
141124 // TODO: Think about how to handle scales with TMA
142125 TensorMapInputsLaunch :: new ( lhs, rhs, bias. into ( ) )
0 commit comments